In [18]:
import sys
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F

In [19]:
sys.path.append("..")

In [20]:
from models.fpn import FPN

In [43]:
def get_encoder(pretrained: str = "random"):
    
    if pretrained == "random":
        encoder = torchvision.models.resnet34(weights=None)
    elif pretrained == "imagenet":
        encoder = torchvision.models.resnet34(pretrained=True)
#     elif pretrained == "unsupervised":
#         encoder = get_unsupervised_pretrained_encoder()

    return encoder, [512, 512, 256, 128, 64],

class SemiSupervisedNetwork(nn.Module):
    
    def __init__(self, training: bool = True):
        super().__init__()
        
        self.training = training
        self.ema_decay = 0.99

        self.b1 = FPN(pretrained="random", classes=1)
        self.b2 = FPN(pretrained="random", classes=1)
        
        for param in self.b2.parameters():
            param.detach_()
            
        for t_param, s_param in zip(self.b2.parameters(), self.b1.parameters()):
            t_param.data.copy_(s_param.data)
        
    def forward(self, x, update_w: bool = False):
        
        if not self.training:
            pred = self.b1(x)
            return pred
          
        s_out = self.b1(x)
        
        with torch.no_grad():
            t_out = self.b2(x)
            
        if update_w:
            self._update_ema_variables(self.ema_decay) 
            
        return s_out, t_out
    
    def _update_ema_variables(self, ema_decay):
        for t_param, s_param in zip(self.b2.parameters(), self.b1.parameters()):
            t_param.data.mul_(ema_decay).add_(1 - ema_decay, s_param.data)
        

In [44]:
m = SemiSupervisedNetwork()

In [45]:
# imgs = minibatch['data']
# masks = minibatch['mask']
# uimgs = minibatch['udata']

imgs = torch.rand((2, 3, 256, 256))
masks = torch.rand((2, 1, 256, 256))
uimgs = torch.rand((2, 3, 256, 256))

In [46]:
sup_loss_fn = nn.BCEWithLogitsLoss()
unsup_loss_fn = nn.MSELoss(reduction='mean')

In [54]:
spreds, tpreds = m(imgs, update_w = True)
sunpreds, tunpreds = m(uimgs, update_w = False)

s_pred = torch.cat([spreds, sunpreds], dim=0)
t_pred = torch.cat([tpreds, tunpreds], dim=0)

loss_unsup = unsup_loss_fn(
    F.sigmoid(s_pred).round(),
    F.sigmoid(t_pred).round().detach())

# supervised loss
loss_sup = sup_loss_fn(spreds, masks)
loss_sup, loss_unsup

(tensor(0.9371, grad_fn=<BinaryCrossEntropyWithLogitsBackward0>),
 tensor(0.2995, grad_fn=<MseLossBackward0>))