In [85]:
import numpy as np
import pandas as pd
import torchvision.models as models
import torchvision
import torch.nn as nn
import torch

from data_loader import *

In [140]:
def flatten(student, teacher):
    '''
    Function used to get BasicBlocks from ResNet class model
    '''
    student_layers = [student.layer1, student.layer2, student.layer3, student.layer4]
    teacher_layers = [teacher.layer1, teacher.layer2, teacher.layer3, teacher.layer4]

    student_blocks = []
    teacher_blocks = []
    
    for i in range(len(student_layers)):
        teacher_blocks += list(np.array_split(teacher_layers[i], len(student_layers[i]))) # divide teacher blocks into n list, where n is number of student blocks
        student_blocks += [el for el in student_layers[i]]

    return student_blocks, teacher_blocks

In [142]:
student, teacher = flatten(resnet18, resnet34)

In [164]:
def forward(x, student, teacher, p):
    '''
    Forward function for hybrid ResNet 
    '''
    def _forward_blocks(x, student_blocks, teacher_blocks, p):
        '''
        Forward function containing only hybrid blocks predicitons
        '''

        len_teacher_blocks = len(teacher_blocks)
        len_student_blocks = len(student_blocks)


        assert len_teacher_blocks == len_student_blocks   # check if size of blocks is the same
        
        a_all = [np.random.binomial(1, p) for i in range(len_student_blocks)]   # hybrid block building schema 

        tmp_x = x
        for i in range(len_student_blocks): # hybrid block
            print(i)
            if a_all[i] == 1: # student path
                tmp_x = student_blocks[i].forward(tmp_x)

            if a_all[i] == 0: # teacher path
                for j in range(len(teacher_blocks[i])):
                    tmp_x = teacher_blocks[i][j].forward(tmp_x)

        return tmp_x, a_all
    
    student_blocks, teacher_blocks = flatten(student, teacher)

    tmp_x = x     # forward pipeline
    tmp_x = student.conv1(tmp_x)
    tmp_x = student.bn1(tmp_x)
    tmp_x = student.relu(tmp_x)
    tmp_x = student.maxpool(tmp_x)
    tmp_x, a_all = _forward_blocks(tmp_x, student_blocks, teacher_blocks, p)
    tmp_x = student.avgpool(tmp_x)
    tmp_x = torch.flatten(tmp_x, 1)
    output = student.fc(tmp_x)

    return output, a_all


In [138]:
resnet34 = models.resnet34(pretrained=False)
resnet18 = models.resnet18(pretrained=False)

In [165]:
forward(images, resnet18, resnet34, 0.5)

0
1
2
3
4
5
6
7


(tensor([[-0.3836,  0.5225,  0.8997,  ...,  0.3284, -0.3004, -1.0443],
         [-0.4587,  0.0323,  0.6330,  ...,  0.6546, -0.2758, -0.7741],
         [-0.4359,  0.1020,  0.5935,  ...,  0.4528, -0.3445, -0.7543],
         ...,
         [-0.2938,  0.1988,  0.6536,  ...,  0.2726, -0.2325, -0.7988],
         [-0.3176, -0.1241,  0.4656,  ...,  0.4756, -0.4430, -0.8032],
         [-0.4971, -0.0197,  0.4594,  ...,  0.5301, -0.3253, -0.7254]],
        grad_fn=<AddmmBackward0>),
 [0, 0, 1, 1, 0, 0, 1, 0])