In [1]:
import warnings
warnings.filterwarnings('ignore')
from glob import glob 
import os 

import pandas as pd 
import numpy as np 
import matplotlib.pyplot as plt 
from PIL import Image 
import cv2


import torch 
import torch.nn as nn 
from torch.utils.data import Dataset,DataLoader
import torchvision.transforms as transforms 
import torchvision
import timm 

from src.data.augmentation import *
from src.data.factory import create_dataset,create_dataloader
from src.options import Options

In [2]:
# Parse the experiment options, then build datasets and their loaders.
# NOTE(review): if Options().parse() wraps argparse over sys.argv it may trip on
# the Jupyter kernel's own command-line arguments — confirm it is notebook-safe.
cfg = Options().parse()
trainset, testset = create_dataset(cfg)

# Same batch size for both splits; only the training split is shuffled.
train_loader = create_dataloader(trainset, cfg['Batchsize'], shuffle=True)
test_loader = create_dataloader(testset, cfg['Batchsize'], shuffle=False)

In [4]:
# FIXME(review): `build_net` is defined in cell In [37] and `device` in In [43];
# on Restart & Run All this cell raises NameError — move it below those cells.
# Within this notebook these globals are not used by `Model`, which calls
# build_net itself to create its own teacher/student pair.
teacher, student = (build_net(flag).to(device) for flag in (True, False))


In [37]:
def build_net(pretrained=False, model_name='resnet18'):
    """Build a timm backbone with its last two children removed.

    Args:
        pretrained: if True, load pretrained weights and freeze every
            parameter (teacher network); otherwise return a randomly
            initialised, trainable network (student network).
        model_name: timm architecture name, e.g. ``'resnet18'`` (default) or
            ``'wide_resnet101_2'``.

    Returns:
        ``torch.nn.Sequential`` holding the backbone's children except the
        last two. For timm ResNets those are presumably the global pool and
        the classifier, so the output is a spatial feature map — TODO confirm
        for other architectures.
    """
    net = timm.create_model(model_name, pretrained=pretrained)
    # Both the teacher and the student share the same truncation.
    model = torch.nn.Sequential(*list(net.children())[:-2])
    if pretrained:
        # Freeze the teacher: no gradients are tracked for its weights.
        model.requires_grad_(False)
    return model

class Model(nn.Module):
    """Teacher/student feature-extractor pair for feature-level distillation.

    The teacher is ``build_net(True)`` (pretrained, parameters frozen) and the
    student is ``build_net()`` (randomly initialised). Both are
    ``nn.Sequential`` containers whose children are named '0', '1', ...

    Args:
        training_type: ``'default'`` runs teacher and student as two separate
            chains on the same input; ``'independent'`` feeds every student
            stage after the stem with the teacher's input to that stage.
        device: kept for backward compatibility and not used here; move the
            module with ``.to(device)`` instead.

    ``forward`` returns ``(t_features, s_features)``: lists of matching
    teacher/student stage outputs.
    """

    _TRAINING_TYPES = ('default', 'independent')
    # Children run by the teacher only in 'independent' mode.
    _STEM = ('0', '1', '2', '3')
    # Children whose outputs are collected in 'default' mode.
    _FEATURE_STAGES = ('4', '5', '6', '7')

    def __init__(self, training_type='default', device='cuda'):
        super().__init__()
        if training_type not in self._TRAINING_TYPES:
            # Previously any unknown string silently fell back to 'independent'.
            raise ValueError(
                f"training_type must be one of {self._TRAINING_TYPES}, "
                f"got {training_type!r}"
            )
        self.teacher = build_net(True)
        self.student = build_net()
        self.training_type = training_type

    def train(self, mode=True):
        """Set train/eval mode, but always keep the frozen teacher in eval mode.

        The teacher's weights are frozen, yet normalisation layers in train
        mode would still use batch statistics and update running stats; keeping
        the teacher in eval mode makes its features a fixed target.
        """
        super().train(mode)
        self.teacher.eval()
        return self

    def _stage_pairs(self):
        """Yield ((teacher_name, teacher_module), (student_name, student_module))."""
        return zip(self.teacher.named_children(), self.student.named_children())

    def train_independent_student(self, x):
        """Train each student stage in isolation against its teacher stage.

        Stem children ('0'-'3') are run by the teacher only. For every later
        child, the student receives a copy of the teacher's input to that
        child, and both outputs are recorded.

        Returns:
            (t_features, s_features): one teacher/student output per child
            after the stem.
        """
        t_features = []
        s_features = []
        for (t_name, t_module), (_, s_module) in self._stage_pairs():
            if t_name in self._STEM:
                x = t_module(x)
                continue
            # Clone so the student sees the teacher's input even if the
            # teacher module modifies its input in place.
            s = x.clone()
            x = t_module(x)
            s = s_module(s)
            t_features.append(x)
            s_features.append(s)
        return t_features, s_features

    def train_default_student(self, x):
        """Run teacher and student as two independent chains on the same input.

        Returns:
            (t_features, s_features): outputs of children '4'-'7' from each
            network.
        """
        t_features = []
        s_features = []
        # Each network must consume its own previous activation; feeding the
        # raw input to every layer breaks as soon as channel counts change.
        x_t = x
        x_s = x
        for (t_name, t_module), (_, s_module) in self._stage_pairs():
            x_s = s_module(x_s)
            x_t = t_module(x_t)
            if t_name in self._FEATURE_STAGES:
                s_features.append(x_s)
                t_features.append(x_t)
        return t_features, s_features

    def forward(self, x):
        """Dispatch to the configured training mode; returns (t_features, s_features)."""
        if self.training_type == 'default':
            return self.train_default_student(x)
        return self.train_independent_student(x)

In [43]:
# Prefer the second GPU as before, but fall back to the first GPU or the CPU so
# the notebook still runs on machines with fewer devices.
if torch.cuda.device_count() > 1:
    device = 'cuda:1'
elif torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# One training batch; the mask is not used in this cell.
img, msk = next(iter(train_loader))
img = img.type(torch.float32).to(device)
model = Model('independent')
model = model.to(device)
t_features, s_features = model(img)

In [44]:
# Pick one collected feature stage to compare (0 = first collected stage).
i = 0
# t_f: teacher features, t_s: student features for that stage.
t_f, t_s = t_features[i], s_features[i]

In [48]:
# Mean-squared error between student and teacher features for stage i.
# MSE is symmetric; the order just follows the (input, target) convention.
criterion = nn.MSELoss()
loss = criterion(t_s, t_f)
loss.backward()