In [1]:
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models

### model

In [100]:
class Args:
    """Plain namespace holding the hyper-parameters used by the cells below.

    All settings are class attributes, so they can be read either from the
    class itself or from the `args` instance created right after it.
    """

    # --- filesystem locations -------------------------------------------
    data = "../ssl/data"
    log_dir = "../ssl/logs"

    # --- image encoder (consumed by CPCModel -> EncoderResNet) ----------
    resnet_model = "resnet50"
    embed_size = 128

    # --- recurrent encoder/decoder (consumed by CPCModel -> Seq2seq) ----
    encoder_model = "LSTM"
    decoder_model = "LSTM"
    rnn_hidden_size = 64
    rnn_n_layers = 2
    rnn_seq_len = 3
    output_size = 128

    # --- loss, optimiser and runtime ------------------------------------
    use_cosine_similarity = True
    weight_decay = 1e-2
    device = "cuda:0"


args = Args()

In [2]:
class EncoderResNet(nn.Module):
    """ResNet backbone followed by a two-layer projection MLP, applied per frame.

    The torchvision network named by `base_model` is split into
    `features` (all children except the last two) and `pool` (the
    second-to-last child); its final `fc` layer is only used to read the
    feature width and is otherwise discarded.

    Args:
        base_model: one of the keys of `_BACKBONES`.
        out_dim: size of the projected embedding produced by `l2`.

    Raises:
        ValueError: if `base_model` is not a supported backbone name.
    """

    # Supported name -> name of the torchvision constructor. Only names are
    # stored so that just the requested backbone is ever instantiated
    # (the previous dict built all six networks on every construction).
    _BACKBONES = {
        "resnet18": "resnet18",
        "resnet50": "resnet50",
        "resnet101": "resnet101",
        "resnet152": "resnet152",
        "resnext50_wide": "resnext50_32x4d",
        "resnext101_wide": "resnext101_32x8d",
    }

    def __init__(self, base_model, out_dim):
        super().__init__()

        resnet = self._get_basemodel(base_model)
        num_ftrs = resnet.fc.in_features

        children = list(resnet.children())
        self.features = nn.Sequential(*children[:-2])
        self.pool = nn.Sequential(*children[-2:-1])

        # projection MLP
        self.l1 = nn.Linear(num_ftrs, num_ftrs)
        self.l2 = nn.Linear(num_ftrs, out_dim)
        self.weights_init()

    def _get_basemodel(self, model_name):
        """Build and return the randomly initialised backbone `model_name`."""
        if model_name not in self._BACKBONES:
            raise ValueError(
                "Invalid model name %r. Check the config file and pass one of: %s"
                % (model_name, ", ".join(sorted(self._BACKBONES))))
        builder = getattr(models, self._BACKBONES[model_name])
        # No arguments -> no pretrained weights, equivalent to the former
        # `pretrained=False` on every torchvision version.
        model = builder()
        print("Feature extractor:", model_name)
        return model

    def forward(self, x):
        """Encode every frame of a batch of sequences.

        Args:
            x: 5-D tensor unpacked as (batch, seq, channel, height, width).

        Returns:
            (f, x): `f` is the raw output of `features` for the flattened
            batch*seq frames; `x` is the projected embedding of shape
            (batch, seq, out_dim).
        """
        batch, seq, channel, height, width = x.shape
        # reshape (not view) so non-contiguous inputs are accepted too.
        frames = x.reshape(batch * seq, channel, height, width)
        f = self.features(frames)
        # Explicit reshape instead of squeeze(): squeeze() would also drop the
        # frame axis when batch * seq == 1.
        h = self.pool(f).reshape(batch, seq, -1)
        x = self.l2(F.relu(self.l1(h)))
        return f, x

    def weights_init(self):
        """Apply Xavier initialisation to every Conv2d in this module.

        Previously this method tested `isinstance(self, nn.Conv2d)`, which is
        never true, so it silently did nothing (and referenced an undefined
        `xavier` helper).
        """
        self.apply(self._init_conv)

    @staticmethod
    def _init_conv(module):
        """Xavier-uniform the weight of a Conv2d and zero its bias, if any."""
        if isinstance(module, nn.Conv2d):
            nn.init.xavier_uniform_(module.weight)
            # Xavier needs >= 2 dims, so the 1-D bias is zeroed instead;
            # convolutions created with bias=False have bias None.
            if module.bias is not None:
                nn.init.zeros_(module.bias)

In [4]:
# from torchsummary import summary
# model = EncoderResNet('resnet50', 512).cuda()

In [5]:
# model(torch.Tensor(1, 2, 3, 416, 416).cuda())

In [7]:
# https://github.com/sthalles/SimCLR/blob/master/models/resnet_simclr.py
# https://pytorch.org/tutorials/intermediate/seq2seq_translation_tutorial.html

In [6]:
class Seq2seq(nn.Module):
    """Recurrent encoder/decoder operating on time-major sequences.

    The encoder consumes `encoder_inputs` and its final hidden state is
    mapped (through `rnn_lr` when `hidden_dim != output_dim`) to the
    initial hidden state of the decoder, which is then unrolled on all-zero
    inputs.

    Args:
        input_dim: feature size of each encoder input step.
        hidden_dim: hidden size of the encoder (and input size of the decoder).
        output_dim: hidden size of the decoder, i.e. the output feature size.
        num_layers: number of stacked recurrent layers in both networks.
        seq_len: number of rows in the returned output tensor.
        encoder_model: 'RNN', 'LSTM' or 'GRU'.
        decoder_model: 'RNN', 'LSTM' or 'GRU'.

    Raises:
        KeyError: if either model name is not one of the supported types.
    """

    # Name -> class; only the two requested networks get instantiated.
    _RNN_TYPES = {"RNN": nn.RNN, "LSTM": nn.LSTM, "GRU": nn.GRU}

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers=2, seq_len=3,
                 encoder_model='GRU', decoder_model='GRU'):
        super().__init__()
        # Retained for backward compatibility and as a last-resort default;
        # forward() creates its tensors on the device of its input instead.
        self.device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu")

        self.hidden_dim = hidden_dim
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.num_layers = num_layers
        self.seq_len = seq_len
        self.encoder_model = encoder_model
        self.decoder_model = decoder_model

        # Maps the flattened per-sample encoder state to the decoder state size.
        self.rnn_lr = nn.Linear(hidden_dim * num_layers, output_dim * num_layers)

        self.encoder = self._get_encodermodel(encoder_model)
        self.decoder = self._get_decodermodel(decoder_model)

        self.weights_init()

    def _build_rnn(self, model_name, in_dim, out_dim):
        """Instantiate the recurrent network `model_name` with the given sizes."""
        try:
            rnn_cls = self._RNN_TYPES[model_name]
        except KeyError:
            raise KeyError("Unknown RNN model %r; expected one of %s"
                           % (model_name, sorted(self._RNN_TYPES))) from None
        print("RNN model:", model_name)
        return rnn_cls(in_dim, out_dim, num_layers=self.num_layers)

    def _get_encodermodel(self, model_name):
        """Return a new encoder network (input_dim -> hidden_dim)."""
        return self._build_rnn(model_name, self.input_dim, self.hidden_dim)

    def _get_decodermodel(self, model_name):
        """Return a new decoder network (hidden_dim -> output_dim)."""
        return self._build_rnn(model_name, self.hidden_dim, self.output_dim)

    def _default_device(self):
        """Device of this module's parameters, falling back to `self.device`."""
        param = next(self.parameters(), None)
        return self.device if param is None else param.device

    def initHidden(self, batch_size, device=None, dtype=None):
        """Zero initial hidden state for the *encoder*.

        Returns an (h, c) tuple when the encoder is an LSTM and a single
        tensor otherwise. (It used to key on `decoder_model`, which handed
        the encoder a state of the wrong kind whenever the two types differed.)
        """
        if device is None:
            device = self._default_device()
        h0 = torch.zeros((self.num_layers, batch_size, self.hidden_dim),
                         device=device, dtype=dtype)
        if self.encoder_model == 'LSTM':
            return h0, torch.zeros_like(h0)
        return h0

    def initInput(self, batch_size, device=None, dtype=None):
        """All-zero decoder input of shape (seq_len, batch_size, hidden_dim)."""
        if device is None:
            device = self._default_device()
        return torch.zeros((self.seq_len, batch_size, self.hidden_dim),
                           device=device, dtype=dtype)

    def weights_init(self):
        """Xavier-normal the recurrent weight matrices and zero the biases.

        The former implementation tested `isinstance(self, (nn.RNN, ...))`,
        which is never true, so it did nothing; it also referenced undefined
        `xavier_*` names and attributes the recurrent modules do not have.
        """
        for rnn in (self.encoder, self.decoder):
            for name, param in rnn.named_parameters():
                if 'weight' in name:
                    nn.init.xavier_normal_(param)
                elif 'bias' in name:
                    # Xavier is undefined for 1-D tensors.
                    nn.init.zeros_(param)

    def _project_state(self, state, batch_size):
        """Map a (layers, batch, hidden_dim) state to (layers, batch, output_dim)."""
        if self.hidden_dim == self.output_dim:
            # Sizes already agree; this case used to leave the decoder state
            # unassigned and crash.
            return state
        flat = state.transpose(0, 1).reshape(batch_size, -1)   # (batch, layers*hidden)
        projected = self.rnn_lr(flat)                           # (batch, layers*output)
        # Un-flatten per sample first, then move layers to the front. Reshaping
        # straight to (layers, batch, output) would mix different samples.
        return (projected.view(batch_size, self.num_layers, self.output_dim)
                .transpose(0, 1).contiguous())

    def forward(self, encoder_inputs):
        """Run the encoder over the inputs and unroll the decoder.

        Args:
            encoder_inputs: tensor indexed as (step, batch, input_dim); at
                most `seq_len` steps.

        Returns:
            Tensor of shape (seq_len, batch, output_dim). Rows beyond the
            number of input steps stay zero.

        Raises:
            ValueError: if the input has more than `seq_len` steps.
        """
        input_length = encoder_inputs.size(0)
        batch_size = encoder_inputs.size(1)
        if input_length > self.seq_len:
            raise ValueError("input has %d steps but seq_len is %d"
                             % (input_length, self.seq_len))
        # Follow the input so that replicas on other devices (data_parallel)
        # and non-default dtypes work.
        device, dtype = encoder_inputs.device, encoder_inputs.dtype

        encoder_hidden = self.initHidden(batch_size, device=device, dtype=dtype)
        if input_length > 0:
            # One call over the whole sequence is equivalent to stepping
            # through it one time step at a time.
            _, encoder_hidden = self.encoder(encoder_inputs, encoder_hidden)

        enc_is_lstm = isinstance(encoder_hidden, tuple)
        h_dec = self._project_state(
            encoder_hidden[0] if enc_is_lstm else encoder_hidden, batch_size)
        if self.decoder_model == 'LSTM':
            if enc_is_lstm:
                c_dec = self._project_state(encoder_hidden[1], batch_size)
            else:
                # Non-LSTM encoders have no cell state to hand over.
                c_dec = torch.zeros_like(h_dec)
            decoder_hidden = (h_dec, c_dec)
        else:
            decoder_hidden = h_dec

        decoder_inputs = self.initInput(batch_size, device=device, dtype=dtype)
        decoder_outputs = torch.zeros(
            (self.seq_len, batch_size, self.output_dim), device=device, dtype=dtype)
        if input_length > 0:
            step_outputs, _ = self.decoder(
                decoder_inputs[:input_length], decoder_hidden)
            decoder_outputs[:input_length] = step_outputs

        return decoder_outputs

In [125]:
from torch.nn.parallel import data_parallel

class CPCModel(nn.Module):
    """Query encoder + momentum ("key") encoder + seq2seq predictor.

    `encoder_q` is trained by gradient; `encoder_k` starts as a copy of it,
    has gradients disabled, and is moved towards `encoder_q` by
    `update_encoder_k` with momentum `m`.

    Args:
        args: configuration object providing `resnet_model`, `embed_size`,
            `rnn_hidden_size`, `output_size`, `rnn_n_layers`, `rnn_seq_len`,
            `encoder_model` and `decoder_model`.
    """

    def __init__(self, args):
        super().__init__()
        self.encoder_q = EncoderResNet(args.resnet_model, args.embed_size)
        self.encoder_k = EncoderResNet(args.resnet_model, args.embed_size)

        self.seq2seq = Seq2seq(args.embed_size, args.rnn_hidden_size, args.output_size,
                               num_layers=args.rnn_n_layers, seq_len=args.rnn_seq_len,
                               encoder_model=args.encoder_model,
                               decoder_model=args.decoder_model)

        self.device = self._get_device()

        for param_q, param_k in zip(self.encoder_q.parameters(), self.encoder_k.parameters()):
            param_k.data.copy_(param_q.data)  # initialize
            param_k.requires_grad = False  # not update by gradient

        self.encoder_q.to(self.device)
        self.encoder_k.to(self.device)
        self.seq2seq.to(self.device)

        # Momentum coefficient used by update_encoder_k.
        self.m = 0.99

    def _get_device(self):
        """Return 'cuda' when CUDA is available, otherwise 'cpu'."""
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        print("Running on:", device)
        return device

    def _run(self, module, inputs, dim=0):
        """Move `inputs` to `self.device` and run `module` on them.

        `data_parallel` is used only when the input lives on a CUDA device
        and several GPUs are present; otherwise the module is called directly
        (calling `data_parallel` without any GPU raises).
        `dim` is the batch axis along which inputs are split across GPUs.
        """
        inputs = inputs.to(self.device)
        if inputs.is_cuda and torch.cuda.device_count() > 1:
            return data_parallel(module, inputs, dim=dim)
        return module(inputs)

    def encode(self, inputs):
        """Projected embeddings of `inputs` from the query encoder."""
        _, x = self._run(self.encoder_q, inputs)
        return x

    def get_feature(self, inputs):
        """First output of the query encoder (output of its `features` stack)."""
        f, _ = self._run(self.encoder_q, inputs)
        return f

    def encode_fixed(self, inputs):
        """Projected embeddings of `inputs` from the momentum encoder.

        Runs under `no_grad`: the momentum encoder's parameters do not
        require gradients, so no graph needs to be kept.
        """
        with torch.no_grad():
            _, x = self._run(self.encoder_k, inputs)
        return x

    def forward(self, inputs):
        """Encode `inputs` with the query encoder and pass them through seq2seq.

        Returns the seq2seq output with the first two axes swapped back, i.e.
        indexed as (batch, step, feature).
        """
        _, x = self._run(self.encoder_q, inputs)
        x = x.transpose(0, 1)  # to time-major for the recurrent model
        # After the transpose the batch axis is dim 1; scattering along dim 0
        # would split the sequence across GPUs instead of the batch.
        x = self._run(self.seq2seq, x, dim=1)
        outputs = x.transpose(0, 1)
        return outputs

    def update_encoder_k(self):
        """Momentum update: k <- m * k + (1 - m) * q for every parameter pair."""
        for param_q, param_k in zip(self.encoder_q.parameters(), self.encoder_k.parameters()):
            param_k.data = param_k.data * self.m + param_q.data * (1. - self.m)

    

In [102]:
# Build the full CPC model from the notebook-level `args` configuration.
model=CPCModel(args)

Feature extractor: resnet50
Feature extractor: resnet50
RNN model: LSTM
RNN model: LSTM
Running on: cuda


In [53]:
# Forward-pass shape check on a dummy batch laid out as
# (batch, seq, channel, height, width).
# torch.randn gives well-defined values; torch.Tensor(*sizes) only allocates
# uninitialised memory, which may contain NaN/inf.
inputs = torch.randn(1, 2, 3, 256, 306)
targets = torch.randn(1, 2, 3, 256, 306)
# Only the shape is needed, so skip autograd bookkeeping to keep memory low.
with torch.no_grad():
    output_shape = model(inputs).shape
output_shape

RuntimeError: CUDA out of memory. Tried to allocate 2.00 MiB (GPU 0; 11.17 GiB total capacity; 10.08 GiB already allocated; 704.00 KiB free; 817.11 MiB cached)

In [11]:
# Shape of the query encoder's projected embeddings for the dummy batch.
model.encode(inputs).shape

torch.Size([1, 3, 128])

In [12]:
# NOTE(review): the CPCModel class defined in this notebook has no
# `get_negative` method, so this cell fails on a fresh kernel. It presumably
# predates a rename — confirm which method it should call.
model.get_negative(inputs).shape

torch.Size([1, 3, 128])

### Loss Function

In [117]:
class CPCLoss(torch.nn.Module):
    """Sum of pairwise similarities between two batches of embeddings.

    Args:
        args: configuration object providing `device`,
            `use_cosine_similarity` and `embed_size`.

    `forward(x, y)` returns the scalar sum of the chosen similarity between
    corresponding vectors of `x` and `y` (cosine similarity along the last
    axis, or a plain dot product).

    NOTE(review): the returned value is a summed *similarity*; whether it is
    used with the right sign as a loss depends on the caller.
    """

    def __init__(self, args):
        super().__init__()

        # Stored for reference only; results stay on the device of the inputs.
        self.device = args.device
        self.use_cosine = args.use_cosine_similarity
        self.embed_size = args.embed_size
        self.loss_function = self._get_similarity_function(
            args.use_cosine_similarity)

    def _get_similarity_function(self, use_cosine_similarity):
        """Return the similarity callable selected by the flag."""
        if use_cosine_similarity:
            self._cosine_similarity = torch.nn.CosineSimilarity(dim=-1)
            return self._cosine_similarity
        else:
            return self._dot_simililarity

    def _dot_simililarity(self, x, y):
        """Sum of dot products between corresponding `embed_size`-vectors.

        This was a `@staticmethod` that still referenced `self`, so the
        non-cosine path always raised NameError. `reshape` replaces `view`
        so that non-contiguous inputs (e.g. transposed outputs) are accepted.
        """
        v = torch.tensordot(x.reshape(-1, self.embed_size),
                            y.reshape(-1, self.embed_size), dims=2)
        return v

    def forward(self, x, y):
        """Return the summed similarity between `x` and `y` as a scalar tensor."""
        return torch.sum(self.loss_function(x, y))

In [118]:
# Similarity-based criterion configured from the notebook-level `args`.
criterion = CPCLoss(args)

In [None]:
## True
# the near one (presumably the positive pair: prediction vs. its target)
# NOTE(review): `x` and `y` are not defined in any visible cell; this is an unexecuted sketch.

model(x), y

In [None]:
## Negative
# the far one (presumably the negative pair)
# NOTE(review): bare `encode` is not defined at notebook level — likely meant
# `model.encode`; `x` and `y` are not defined in any visible cell either.


encode(x), y

### data

In [13]:
from dataset_wrapper import DataSetWrapper

In [28]:
# Build the dataset wrapper (imported from the project-local `dataset_wrapper`
# module) and obtain the loader iterated in the next cell.
data = DataSetWrapper(batch_size=2, seq_length=3, num_workers=0, s=1)
train_loader = data.get_data_loader()

In [29]:
# Grab the first batch only: print the shapes of its first two elements and
# keep the batch around as `sample_batch` for the cells below.
for i in train_loader:
    print(i[0].shape)
    print(i[1].shape)
    sample_batch = i
    break

torch.Size([2, 3, 3, 416, 416])
torch.Size([2, 3, 3, 416, 416])


In [32]:
# Rebuild the model and run the sampled batch through it as a shape check.
# CPCModel.__init__ already moves its sub-modules to its own device; the
# explicit .cuda() additionally requires a CUDA device to be present.
model=CPCModel(args).cuda()
model(sample_batch[0].cuda()).shape

Feature extractor: resnet50
RNN model: LSTM
RNN model: LSTM


torch.Size([2, 3, 512])

In [33]:
# Seed PyTorch's global RNG.
# NOTE(review): in notebook order this runs after the models above were
# constructed, so their initial weights are not covered by this seed — confirm
# whether it should move to the top of the notebook.
torch.manual_seed(0)

<torch._C.Generator at 0x2b09ebf8f470>

### Run_train.py

In [126]:
class CPC_train(object):
    """Training driver bundling the dataset, the CPC model and its criterion.

    Args:
        dataset: object exposing `get_data_loader()`.
        args: configuration object; passed on to `CPCLoss` and `CPCModel`
            and read here for `device`, `weight_decay` and `log_dir`.
    """

    def __init__(self, dataset, args):
        self.args = args
        self.device = args.device
        self.dataset = dataset
        self.criterion = CPCLoss(args)
        self.model = CPCModel(args)

    def _step(self, x, y):
        """Compute the criterion value for one batch.

        `y` is encoded by the momentum encoder and L2-normalised along dim 2.
        The result is the criterion between the model's normalised prediction
        for `x` and that target, plus the criterion between the normalised
        query-encoder embedding of `x` and the negated target.

        NOTE(review): the criterion is a summed similarity, so minimising the
        first term as written lowers the prediction/target similarity —
        confirm the intended sign against the training loop.
        """
        y = self.model.encode_fixed(y)
        y = F.normalize(y, dim=2)

        y_hat = self.model(x)
        y_hat = F.normalize(y_hat, dim=2)

        loss = self.criterion(y_hat, y)

        # `encode` lives on the model, not on this trainer (was `self.encode`,
        # which raised AttributeError).
        x_hat = self.model.encode(x)
        x_hat = F.normalize(x_hat, dim=2)

        loss = loss + self.criterion(x_hat, -y)
        return loss

    def train(self):
        """Set up the loader, optimiser, scheduler and checkpoint folder.

        NOTE(review): the method as visible here stops after saving the
        config; no training loop is shown. `_load_pre_trained_weights` and
        `_save_config_file` are not defined in the visible cells — confirm
        where they come from.
        """
        train_loader = self.dataset.get_data_loader()
        model = self._load_pre_trained_weights(self.model)
        optimizer = torch.optim.Adam(model.parameters(), 3e-4, weight_decay=self.args.weight_decay)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=len(train_loader), eta_min=0,
                                                               last_epoch=-1)

        model_checkpoints_folder = os.path.join(self.args.log_dir, 'checkpoints')

        _save_config_file(model_checkpoints_folder)
        
        
        
        

       
        