In [None]:
# default_exp model

# Model

In [None]:
#hide

from fastcore.test import *
from nbdev.showdoc import *

In [None]:
#export

import torch
import torch.nn as nn
from fastai.basics import *
from fastai.vision.all import *

In [None]:
#export

class BaselineSTM(Module):
    """Single-task frame-sequence model: shared CNN body + per-frame head, averaged over frames.

    Args:
        arch: architecture constructor handed to fastai's `create_body` (e.g. `resnet18`).
        n_out: number of output logits.
        pretrained: whether `create_body` loads pretrained weights.
    """
    def __init__(self, arch, n_out, pretrained=True):
        store_attr()
        self.encoder = TimeDistributed(create_body(arch, pretrained=pretrained))
        # Probe the body with a 224x224 dummy image to discover its output channel count.
        channels = dummy_eval(self.encoder.module, (224, 224)).shape[1]
        self.head = TimeDistributed(create_head(channels, n_out))

    def forward(self, x):
        """`x` is a list of per-frame batches (assumed (bs, C, H, W) each — TODO confirm).

        Frames are stacked on a new axis 1, encoded and classified frame by frame,
        and the per-frame logits are averaged over that frame axis.
        """
        frames = torch.stack(x, dim=1)
        per_frame_logits = self.head(self.encoder(frames))
        return per_frame_logits.mean(dim=1)

    @staticmethod
    def splitter(model):
        """Two parameter groups for discriminative learning rates: encoder first, then head."""
        encoder_group = params(model.encoder)
        head_group = params(model.head)
        return [encoder_group, head_group]

In [None]:
#export

class BaselineMTM(Module):
    """Multi-task frame-sequence model with one head emitting distortion and severity logits.

    A single `create_head` outputs `n_distortion + n_sev` logits per frame. These are
    averaged over frames and then split into the distortion part (first
    `n_distortion` columns) and the severity part (the remaining columns).

    Args:
        arch: architecture constructor handed to `create_body`.
        n_distortion: number of distortion logits.
        n_sev: number of severity logits.
        pretrained: whether `create_body` loads pretrained weights.
    """
    def __init__(self, arch, n_distortion, n_sev, pretrained=True):
        store_attr()
        self.encoder = TimeDistributed(create_body(arch, pretrained=pretrained))
        channels = dummy_eval(self.encoder.module, (224, 224)).shape[1]
        n_logits = n_distortion + n_sev
        self.head = TimeDistributed(create_head(channels, n_logits))

    def forward(self, x):
        """Return `[distortion_logits, severity_logits]`, each averaged over the frame axis."""
        stacked = torch.stack(x, dim=1)
        logits = self.head(self.encoder(stacked)).mean(dim=1)
        cut = self.n_distortion
        distortion_logits = logits[:, :cut]
        severity_logits = logits[:, cut:]
        return [distortion_logits, severity_logits]

    @staticmethod
    def splitter(model):
        """Two parameter groups: encoder, then head."""
        encoder_group = params(model.encoder)
        head_group = params(model.head)
        return [encoder_group, head_group]

In [None]:
#export

class MultiScaleBackbone(Module):
    """CNN body whose output concatenates pooled features taken from several depths.

    Forward hooks capture the outputs of the body's children `[4:-1]`. For a
    torchvision ResNet body these are presumably layer1-layer3 (this depends on
    `arch`; verify). Each captured map, plus the final body output, is reduced
    with `AdaptiveConcatPool2d`, and the results are concatenated along dim 1.
    The hooks are attached for the module's lifetime; nothing here removes them.
    """
    def __init__(self, arch, pretrained=True):
        store_attr()
        self.backbone = create_body(arch, pretrained=pretrained)
        tapped = list(self.backbone.children())[4:-1]
        # detach=False keeps the hooked activations in the autograd graph so the
        # intermediate-scale features contribute gradients.
        self.hooks = hook_outputs(tapped, detach=False)

    def forward(self, x):
        final_map = self.backbone(x)
        all_maps = list(self.hooks.stored) + [final_map]
        pooled = [AdaptiveConcatPool2d()(fmap) for fmap in all_maps]
        return torch.cat(pooled, dim=1)
    
class MultiScaleMTM(Module):
    """Multi-task frame-sequence model built on `MultiScaleBackbone` features.

    The per-frame multi-scale features go through one `create_head` emitting
    `n_distortion + n_sev` logits. These are averaged over frames and split into
    `[distortion_logits, severity_logits]`.
    """
    def __init__(self, arch, n_distortion, n_sev, pretrained=True):
        store_attr()
        self.encoder = TimeDistributed(MultiScaleBackbone(arch, pretrained=pretrained))
        channels = dummy_eval(self.encoder.module, (224, 224)).shape[1]
        n_logits = n_distortion + n_sev
        self.head = TimeDistributed(create_head(channels, n_logits))

    def forward(self, x):
        """`x` is a list of per-frame batches, stacked on a new axis 1 before encoding."""
        encoded = self.encoder(torch.stack(x, dim=1))
        logits = self.head(encoded).mean(dim=1)
        cut = self.n_distortion
        return [logits[:, :cut], logits[:, cut:]]

    @staticmethod
    def splitter(model):
        """Two parameter groups: encoder, then head."""
        encoder_group = params(model.encoder)
        head_group = params(model.head)
        return [encoder_group, head_group]

In [None]:
#export

class SequenceSTM(Module):
    """Single-task model: per-frame CNN embeddings summarised by an LSTM.

    Each frame is encoded to a flat vector (`create_body` -> global average pool
    -> flatten). The frame sequence runs through an LSTM, and the final hidden
    state of every LSTM layer is concatenated per sample and fed to the linear
    part of a fastai `create_head`.

    Args:
        arch: architecture constructor handed to `create_body`.
        n_out: number of output logits.
        num_rnn_layers: number of stacked LSTM layers.
        pretrained: whether `create_body` loads pretrained weights.
    """
    def __init__(self, arch, n_out, num_rnn_layers=1, pretrained=True):
        store_attr()
        self.encoder = TimeDistributed(nn.Sequential(
            create_body(arch, pretrained=pretrained),
            nn.AdaptiveAvgPool2d(1),
            Flatten()
        ))
        n_features = dummy_eval(self.encoder.module, (224, 224)).shape[1]
        self.rnn = nn.LSTM(n_features, n_features, num_layers=num_rnn_layers, batch_first=True)
        # The LSTM already yields flat vectors, so the head's pool+flatten layers are
        # dropped with [2:]. concat_pool=False keeps the remaining first BatchNorm/Linear
        # sized for `num_rnn_layers * n_features` inputs. With the default concat
        # pooling, create_head doubles `nf`, and the sliced head would expect twice
        # as many features as the LSTM provides.
        self.head = create_head(num_rnn_layers * n_features, n_out, concat_pool=False)[2:]

    def forward(self, x):
        x = self.encoder(torch.stack(x, dim=1))
        bs = x.shape[0]
        _, (h, _) = self.rnn(x)
        # h_n is (num_layers, bs, hidden) even with batch_first=True. Move the batch
        # axis first before flattening so each row holds only its own sample's
        # hidden states. A plain view(bs, -1) mixes samples when num_layers > 1.
        return self.head(h.transpose(0, 1).reshape(bs, -1))

    @staticmethod
    def splitter(model):
        """Two parameter groups: CNN encoder, then LSTM + head."""
        return [params(model.encoder), params(model.rnn) + params(model.head)]

In [None]:
#export

class SequenceMTM(Module):
    """Multi-task model: per-frame CNN embeddings summarised by an LSTM.

    Same trunk as the single-task sequence model, but the head emits
    `n_distortion + n_sev` logits. These are split into
    `[distortion_logits, severity_logits]`.

    Args:
        arch: architecture constructor handed to `create_body`.
        n_distortion: number of distortion logits.
        n_sev: number of severity logits.
        num_rnn_layers: number of stacked LSTM layers.
        pretrained: whether `create_body` loads pretrained weights.
    """
    def __init__(self, arch, n_distortion, n_sev, num_rnn_layers=1, pretrained=True):
        store_attr()
        self.encoder = TimeDistributed(nn.Sequential(
            create_body(arch, pretrained=pretrained),
            nn.AdaptiveAvgPool2d(1),
            Flatten()
        ))
        n_features = dummy_eval(self.encoder.module, (224, 224)).shape[1]
        self.rnn = nn.LSTM(n_features, n_features, num_layers=num_rnn_layers, batch_first=True)
        # Output size is both tasks' logits. The previous code referenced an
        # undefined `n_out` and raised NameError on construction.
        # [2:] drops pool+flatten (the LSTM output is already flat). concat_pool=False
        # stops create_head from doubling `nf`, so the first layer matches the LSTM width.
        self.head = create_head(num_rnn_layers * n_features, n_distortion + n_sev,
                                concat_pool=False)[2:]

    def forward(self, x):
        x = self.encoder(torch.stack(x, dim=1))
        bs = x.shape[0]
        _, (h, _) = self.rnn(x)
        # h_n is (num_layers, bs, hidden); put batch first before flattening so
        # samples are not mixed when num_layers > 1.
        out = self.head(h.transpose(0, 1).reshape(bs, -1))
        # A list, matching the other multi-task models in this module.
        return [out[:, :self.n_distortion], out[:, self.n_distortion:]]

    @staticmethod
    def splitter(model):
        """Two parameter groups: CNN encoder, then LSTM + head."""
        return [params(model.encoder), params(model.rnn) + params(model.head)]

In [None]:
#export

class MultiHeadMTM(Module):
    """Multi-task frame model with a shared head trunk and one `LinBnDrop` head per task.

    Per frame: body features -> shared `create_head` (n_features outputs) + ReLU,
    then a distortion head (`n_distortion` logits) and a severity head
    (`n_sev` logits). Each task's logits are averaged over the frame axis.
    """
    def __init__(self, arch, n_distortion, n_sev, pretrained=True):
        store_attr()
        self.encoder = TimeDistributed(create_body(arch, pretrained=pretrained))
        width = dummy_eval(self.encoder.module, (224, 224)).shape[1]
        shared_trunk = nn.Sequential(create_head(width, width), nn.ReLU())
        self.common_head = TimeDistributed(shared_trunk)
        self.dis_head = TimeDistributed(LinBnDrop(width, n_distortion))
        self.sev_head = TimeDistributed(LinBnDrop(width, n_sev))

    def forward(self, x):
        """Return `[distortion_logits, severity_logits]`, each averaged over frames."""
        shared = self.common_head(self.encoder(torch.stack(x, dim=1)))
        return [task_head(shared).mean(dim=1) for task_head in (self.dis_head, self.sev_head)]

    @staticmethod
    def splitter(model):
        """Two parameter groups: encoder, then all head parameters (shared + both tasks)."""
        head_group = params(model.common_head) + params(model.dis_head) + params(model.sev_head)
        return [params(model.encoder), head_group]

In [None]:
#hide
# Smoke test: a list of `n` frame batches yields one logit tensor per task,
# each shaped (batch, n_classes_for_that_task).
bs, n = 2, 5
ndis, nsev = 18, 4
mhmtm = MultiHeadMTM(resnet18, ndis, nsev, pretrained=False)
x = [torch.rand(bs, 3, 224, 224) for _ in range(n)]
y1, y2 = mhmtm(x)
assert tuple(y1.shape) == (bs, ndis)
assert tuple(y2.shape) == (bs, nsev)