In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models

In [2]:
class EncoderResNet(nn.Module):
    def __init__(self, base_model, out_dim):
        super().__init__()
        self.resnet_dict = {"resnet18": models.resnet18(pretrained=False),
                            "resnet50": models.resnet50(pretrained=False), 
                            "resnet101": models.resnet101(pretrained=False), 
                            "resnet152": models.resnet152(pretrained=False), 
                            "resnext50_wide": models.resnext50_32x4d(pretrained=False), 
                            "resnext101_wide": models.resnext101_32x8d(pretrained=False), 
                             }

        resnet = self._get_basemodel(base_model)
        num_ftrs = resnet.fc.in_features

        self.features = nn.Sequential(*list(resnet.children())[:-1])

        # projection MLP
        self.l1 = nn.Linear(num_ftrs, num_ftrs)
        self.l2 = nn.Linear(num_ftrs, out_dim)
        self.weights_init()

    def _get_basemodel(self, model_name):
        try:
            model = self.resnet_dict[model_name]
            print("Feature extractor:", model_name)
            return model
        except:
            raise ("Invalid model name. Check the config file and pass one of: resnet18 or resnet50")

    def forward(self, x):
        batch, seq, channel, height, width = x.shape
        x = x.view(-1, channel, height, width)
        h = self.features(x)
        h = h.squeeze().view(batch, seq, -1)
        x = self.l1(h)
        x = F.relu(x)
        x = self.l2(x).view(batch, seq, -1)
        return h, x
    
    def weights_init(m):
        if isinstance(m, nn.Conv2d):
            xavier(m.weight.data)
            xavier(m.bias.data)

In [3]:
from torchsummary import summary
model = EncoderResNet('resnet50', 512).cuda()

Feature extractor: resnet50


In [4]:
summary(model=model, input_size=(batch, seq, 3, 416, 416))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 64, 208, 208]           9,408
       BatchNorm2d-2         [-1, 64, 208, 208]             128
              ReLU-3         [-1, 64, 208, 208]               0
         MaxPool2d-4         [-1, 64, 104, 104]               0
            Conv2d-5         [-1, 64, 104, 104]           4,096
       BatchNorm2d-6         [-1, 64, 104, 104]             128
              ReLU-7         [-1, 64, 104, 104]               0
            Conv2d-8         [-1, 64, 104, 104]          36,864
       BatchNorm2d-9         [-1, 64, 104, 104]             128
             ReLU-10         [-1, 64, 104, 104]               0
           Conv2d-11        [-1, 256, 104, 104]          16,384
      BatchNorm2d-12        [-1, 256, 104, 104]             512
           Conv2d-13        [-1, 256, 104, 104]          16,384
      BatchNorm2d-14        [-1, 256, 1

(tensor(28753472), tensor(28753472))

In [5]:
# https://github.com/sthalles/SimCLR/blob/master/models/resnet_simclr.py
# https://pytorch.org/tutorials/intermediate/seq2seq_translation_tutorial.html

In [6]:
class Seq2seq(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim, num_layers=2, seq_len=3,
                 encoder_model='GRU', decoder_model='GRU'):
        super().__init__()
        self.device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu")

        self.rnn_encoder_dict = {"RNN": nn.RNN(input_dim, hidden_dim, num_layers=num_layers),
                                 "LSTM": nn.LSTM(input_dim, hidden_dim, num_layers=num_layers),
                                 "GRU": nn.GRU(input_dim, hidden_dim, num_layers=num_layers)
                                 }

        self.rnn_decoder_dict = {"RNN": nn.RNN(hidden_dim, output_dim, num_layers=num_layers),
                                 "LSTM": nn.LSTM(hidden_dim, output_dim, num_layers=num_layers),
                                 "GRU": nn.GRU(hidden_dim, output_dim, num_layers=num_layers)
                                 }

        self.rnn_lr = nn.Linear(hidden_dim*num_layers, output_dim*num_layers)

        self.hidden_dim = hidden_dim
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.num_layers = num_layers

        self.seq_len = seq_len

        self.decoder_model = decoder_model

        self.encoder = self._get_encodermodel(encoder_model)
        self.decoder = self._get_decodermodel(decoder_model)
        
        self.weights_init()

    def _get_encodermodel(self, model_name):
        model = self.rnn_encoder_dict[model_name]
        print("RNN model:", model_name)
        return model

    def _get_decodermodel(self, model_name):
        model = self.rnn_decoder_dict[model_name]
        print("RNN model:", model_name)
        return model

    def initHidden(self, batch_size):
        if self.decoder_model == 'LSTM':
            return (torch.zeros(self.num_layers, batch_size, self.hidden_dim, device=self.device),
                    torch.zeros(self.num_layers, batch_size, self.hidden_dim, device=self.device))
        else:
            return torch.zeros(self.num_layers, batch_size, self.hidden_dim, device=self.device)

    def initInput(self, batch_size):
        return torch.zeros(self.seq_len, batch_size, self.hidden_dim, device=self.device)
    
    def weights_init(m):
        if isinstance(m, (nn.RNN, nn.GRU, nn.LSTM)):
            xavier_normal(m.weight.data)
            xavier_uniform(m.bias.data)
    

    def forward(self, encoder_inputs):
        batch_size = encoder_inputs.size(1)
        encoder_hidden = self.initHidden(batch_size)

        encoder_outputs = torch.zeros(
            self.seq_len, batch_size, self.hidden_dim, device=self.device)

        input_length = encoder_inputs.size(0)
        for ei in range(input_length):
            encoder_output, encoder_hidden = self.encoder(
                encoder_inputs[ei].unsqueeze(0), encoder_hidden)
            encoder_outputs[ei] = encoder_output

        decoder_inputs = self.initInput(batch_size)

        if self.hidden_dim != self.output_dim:
            if self.decoder_model == 'LSTM':
                decoder_hidden = (self.rnn_lr(encoder_hidden[0].transpose_(0, 1).reshape(batch_size, -1)).reshape(-1, batch_size, self.output_dim), 
                                 self.rnn_lr(encoder_hidden[1].transpose_(0, 1).reshape(batch_size, -1)).reshape(-1, batch_size, self.output_dim))

            else:
                decoder_hidden = self.rnn_lr(encoder_hidden.transpose_(0, 1).reshape(batch_size, -1))
                decoder_hidden = decoder_hidden.reshape(-1, batch_size, self.output_dim)

        decoder_outputs = torch.zeros(
            self.seq_len, batch_size, self.output_dim, device=self.device)

        for di in range(input_length):
            decoder_output, decoder_hidden = self.decoder(
                decoder_inputs[di].unsqueeze(0), decoder_hidden)
            decoder_outputs[di] = decoder_output

        return decoder_outputs

In [7]:
class Args:
    data = '../ssl/data'
    resnet_model = 'resnet50'
    encoder_model = 'LSTM'
    decoder_model = 'LSTM'
    embed_size = 512
    rnn_hidden_size = 128
    output_size = 512
    rnn_n_layers = 2
    rnn_seq_len = 3
    

args = Args()

In [8]:
class CPCModel(nn.Module):
    
    def __init__(self, args):
        super().__init__()
        self.Encoder = EncoderResNet(args.resnet_model, args.embed_size)
        self.Seq2Seq = Seq2seq(args.embed_size, args.rnn_hidden_size, args.output_size, num_layers=args.rnn_n_layers, seq_len=args.rnn_seq_len,
                               encoder_model=args.encoder_model, decoder_model=args.decoder_model)
        
        
    def forward(self, inputs):
        reprentation, x = self.Encoder(inputs)
        x = x.transpose_(1, 0)
        outputs = self.Seq2Seq(x)
        return outputs

In [9]:
model=CPCModel(args).cuda()
inputs = torch.Tensor(10, 3, 3, 256, 306).cuda()
model(inputs)

Feature extractor: resnet50
RNN model: LSTM
RNN model: LSTM


tensor([[[-0.0002,  0.0246, -0.0102,  ..., -0.0081,  0.0022,  0.0473],
         [-0.0039,  0.0143,  0.0126,  ..., -0.0223, -0.0217, -0.0057],
         [-0.0002,  0.0246, -0.0102,  ..., -0.0081,  0.0022,  0.0473],
         ...,
         [-0.0039,  0.0143,  0.0126,  ..., -0.0223, -0.0217, -0.0057],
         [-0.0002,  0.0246, -0.0102,  ..., -0.0081,  0.0022,  0.0473],
         [-0.0039,  0.0143,  0.0126,  ..., -0.0223, -0.0217, -0.0057]],

        [[-0.0134,  0.0179, -0.0102,  ..., -0.0199, -0.0039,  0.0205],
         [-0.0206,  0.0111,  0.0023,  ..., -0.0310, -0.0235, -0.0046],
         [-0.0134,  0.0179, -0.0102,  ..., -0.0199, -0.0039,  0.0205],
         ...,
         [-0.0206,  0.0111,  0.0023,  ..., -0.0310, -0.0235, -0.0046],
         [-0.0134,  0.0179, -0.0102,  ..., -0.0199, -0.0039,  0.0205],
         [-0.0206,  0.0111,  0.0023,  ..., -0.0310, -0.0235, -0.0046]],

        [[-0.0218,  0.0134, -0.0098,  ..., -0.0262, -0.0077,  0.0063],
         [-0.0286,  0.0100, -0.0031,  ..., -0

In [72]:
from torchsummary import summary

model = Seq2seq(2048, 2048, 2048).cuda()

RNN model: GRU
RNN model: GRU


In [73]:
summary(model=model, input_size=[(3, 2048), (3, 2048)])

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
               GRU-1  [[-1, 3, 2048], [-1, 3, 2048]]               0
               GRU-2  [[-1, 3, 2048], [-1, 3, 2048]]               0
           Seq2seq-3  [[-1, 3, 2048], [-1, 3, 2048]]               0
Total params: 0
Trainable params: 0
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 144.00
Forward/backward pass size (MB): 864.00
Params size (MB): 0.00
Estimated Total Size (MB): 1008.00
----------------------------------------------------------------



(0, 0)

In [75]:
a.shape

torch.Size([5, 1, 2048])

In [76]:
b.shape

torch.Size([2, 1, 2048])