In [2]:
# Pinned PyTorch 1.11 / torchvision 0.12 / torchaudio 0.11 wheels built against CUDA 11.3.
!pip3 install torch==1.11.0+cu113 torchvision==0.12.0+cu113 torchaudio==0.11.0+cu113 -f https://download.pytorch.org/whl/cu113/torch_stable.html

In [5]:
# NOTE(review): gdown is installed unpinned (and with --pre), so its CLI can change between runs — consider pinning a version.
!pip install -U --no-cache-dir gdown --pre

In [6]:
# Downloads a Google Drive file by ID; presumably the CIFAR-10 archive unzipped below — confirm.
# NOTE(review): recent gdown releases deprecate the `--id` option (the ID can be passed positionally) — confirm against the installed version.
! gdown --id 1enXTrapvY56RzA-F-ec9NoQ2XKpjPAT3

In [7]:
# Extracts the downloaded archive; the code below reads CIFAR-10 from ./data, presumably created here — TODO confirm.
! unzip cifar10

In [17]:
import torch
import torch.nn as nn
import torchvision
from torchvision import datasets
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
import torch.utils.data.distributed
import random
import numpy as np
import itertools
import os
from tqdm import tqdm


class MoCo(nn.Module):
    '''
    Momentum Contrast (MoCo) self-supervised pre-training of RegNetY-400MF on CIFAR-10.

    ``q_backbone`` (query encoder) is trained with AdamW on an InfoNCE loss.
    ``k_backbone`` (key / momentum encoder) is never back-propagated into; it only
    tracks the query encoder through an exponential moving average. For each image
    the query encoder sees the plain view and the key encoder sees an augmented
    view (the positive pair). Negatives are key-encoder features of images drawn
    from an independently shuffled loader and kept in ``self.queue``.

    ``train()`` called without arguments (or with hyper-parameters) runs the
    pre-training loop. Called with a bool, as ``nn.Module.eval()`` does, it only
    switches train/eval mode.
    '''

    def __init__(self):
        super(MoCo, self).__init__()
        torch.manual_seed(2287)
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.q_backbone = torchvision.models.regnet_y_400mf(pretrained=False)
        self.k_backbone = torchvision.models.regnet_y_400mf(pretrained=False)
        # The key encoder starts as an exact copy of the query encoder. After that
        # it is updated only by momentum_synchronize_backbone, never by gradients.
        self.k_backbone.load_state_dict(self.q_backbone.state_dict())
        self.k_backbone.requires_grad_(False)

        # Query view: no augmentation, only conversion to a tensor.
        self.transform_none = transforms.Compose([
            transforms.ToTensor(),
        ])

        # Key view: random flip / rotation / colour jitter / grayscale.
        self.transform_not_none = transforms.Compose([
            transforms.RandomHorizontalFlip(0.5),
            transforms.RandomRotation(degrees=180),
            transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.4),
            transforms.RandomGrayscale(p=0.1),
            transforms.ToTensor(),
        ])
        self.loss = nn.CrossEntropyLoss().to(self.device)
        # List of (batch, C) key-feature tensors used as negatives.
        self.queue = []

        self.train_loader_not_aug, self.train_loader_aug, self.queue_loader = self.get_train_loader()

        # itertools.cycle would keep every batch of the first pass in memory (the
        # whole training set) and replay those batches in the same order forever.
        # Re-iterating the DataLoader keeps memory bounded and reshuffles each pass.
        self.queue_loader = self._repeat(self.queue_loader)

    @staticmethod
    def _repeat(loader):
        '''Yield batches from ``loader`` forever, starting a fresh pass each time it is exhausted.'''
        while True:
            produced = False
            for batch in loader:
                produced = True
                yield batch
            if not produced:
                # Empty loader: stop instead of spinning forever (same as itertools.cycle).
                return

    def q_forward(self, x):
        return self.q_backbone(x)

    def k_forward(self, x):
        return self.k_backbone(x)

    def train(self,
              t=0.1,
              lr=1e-4,
              weight_decay=0,
              batch_size=64,
              m=0.99,
              total_epoch=1000
              ):
        '''
        Run MoCo pre-training, resuming from q_encoder.pth / k_encoder.pth when they exist.

        Both encoders are saved after every epoch. If ``t`` is a bool, the call is
        treated as ``nn.Module.train(mode)`` (this is what ``eval()`` invokes) and
        only the train/eval flag is changed.

        :param t: softmax temperature of the contrastive loss (or a bool, see above)
        :param lr: AdamW learning rate of the query encoder
        :param weight_decay: AdamW weight decay
        :param batch_size: unused; the loaders are built in __init__ with batch size 64
        :param m: momentum of the key encoder's moving average of the query encoder
        :param total_epoch: number of epochs to run
        :return: self when used as a mode switch, otherwise None
        '''
        if isinstance(t, bool):
            # nn.Module.eval() calls self.train(False): keep the standard
            # mode-switching behaviour instead of launching a training run.
            return super(MoCo, self).train(t)
        if os.path.exists('q_encoder.pth'):
            self.q_backbone.load_state_dict(torch.load('q_encoder.pth', map_location=self.device))
        if os.path.exists('k_encoder.pth'):
            self.k_backbone.load_state_dict(torch.load('k_encoder.pth', map_location=self.device))
        # A single optimizer for the whole run, so AdamW's moment estimates
        # persist across epochs instead of being reset every epoch.
        optimizer = torch.optim.AdamW(self.q_backbone.parameters(), lr=lr, weight_decay=weight_decay)
        for i in range(total_epoch):
            loss = self.train_one_epoch(t=t, lr=lr, weight_decay=weight_decay, batch_size=batch_size, m=m,
                                        optimizer=optimizer)
            print(f'epoch {i + 1}, loss = {loss}')
            torch.save(self.q_backbone.state_dict(), 'q_encoder.pth')
            torch.save(self.k_backbone.state_dict(), 'k_encoder.pth')

    def train_one_epoch(self,
                        t=0.1,
                        lr=1e-3,
                        weight_decay=0,
                        batch_size=64,
                        m=0.99,
                        optimizer=None,
                        ):
        '''
        Run one pass over the paired (plain, augmented) training loaders.

        :param t: softmax temperature of the contrastive loss
        :param lr: AdamW learning rate (used only when ``optimizer`` is None)
        :param weight_decay: AdamW weight decay (used only when ``optimizer`` is None)
        :param batch_size: unused; the loaders are built in __init__
        :param m: momentum of the key encoder's moving average of the query encoder
        :param optimizer: optimizer over the query encoder; a fresh AdamW is created when None
        :return: mean per-batch loss of the epoch (0.0 if the loaders are empty)
        '''
        if optimizer is None:
            optimizer = torch.optim.AdamW(self.q_backbone.parameters(), lr=lr, weight_decay=weight_decay)
        self.initialize_queue()
        epoch_loss = 0.0
        n_batches = 0
        paired = zip(self.train_loader_not_aug, self.train_loader_aug)
        for (x_q, _), (x_k, _) in tqdm(paired, total=len(self.train_loader_not_aug)):
            optimizer.zero_grad()

            # (N, C) features, L2-normalised so the logits are cosine similarities.
            q = nn.functional.normalize(self.q_forward(x_q.to(self.device)), dim=1)
            with torch.no_grad():
                k = nn.functional.normalize(self.k_forward(x_k.to(self.device)), dim=1)

            # (N, 1): similarity of each query with its own key (the positive).
            l_pos = (q * k).sum(dim=1, keepdim=True)
            # (K, C) queued negatives -> (N, K) similarities.
            queue = self.get_queue().detach()
            l_neg = q @ queue.t()
            # (N, 1 + K); the positive sits in column 0 of every row.
            logits = torch.cat([l_pos, l_neg], dim=1)
            labels = torch.zeros(q.size(0), dtype=torch.long, device=self.device)

            loss = self.loss(logits / t, labels)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
            n_batches += 1

            self.momentum_synchronize_backbone(m=m)
            self.enqueue(1)
            self.dequeue(1)

        return epoch_loss / n_batches if n_batches else 0.0

    def get_train_loader(self, batch_size=64, seed=2287):
        '''
        Build the three CIFAR-10 training loaders used by MoCo.

        The plain and augmented loaders shuffle with generators seeded identically.
        As long as they are always iterated together (as in train_one_epoch), their
        i-th batches therefore contain the same images and form positive pairs.
        The queue loader is shuffled independently.

        :return: (train_loader_not_aug, train_loader_aug, queue_loader)
        '''
        def paired_loader(transform):
            train_set = datasets.CIFAR10('./data', train=True, transform=transform)
            torch.manual_seed(seed=seed)
            g = torch.Generator()
            # torch.manual_seed only seeds the global RNG. The loader's own
            # generator must be seeded explicitly or ``seed`` has no effect.
            g.manual_seed(seed)
            return DataLoader(train_set, batch_size=batch_size, shuffle=True, generator=g)

        train_loader_not_aug = paired_loader(self.transform_none)
        train_loader_aug = paired_loader(self.transform_not_none)

        # No dedicated generator, so the queue order differs from the loaders above.
        train_set = datasets.CIFAR10('./data', train=True, transform=self.transform_none)
        queue_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)

        return train_loader_not_aug, train_loader_aug, queue_loader

    def get_queue(self):
        '''
        :return: all queued key features concatenated, shape (queue_length*batch, C)
        '''
        return torch.cat(self.queue, dim=0)

    def initialize_queue(self, queue_length=5, batch_size=64):
        '''
        Reset the queue and fill it with ``queue_length`` freshly encoded batches.

        :param queue_length: length in units of loader batches
        :param batch_size: unused
        '''
        del self.queue
        self.queue = []
        self.enqueue(queue_length)

    def enqueue(self, k, ):
        '''
        Encode ``k`` batches from the queue loader with the key encoder and append them.

        Runs under no_grad: the stored features must not keep autograd graphs alive,
        otherwise every queued batch would pin the key encoder's activations in memory.

        :param k: number of loader batches to add
        '''
        with torch.no_grad():
            for _ in range(k):
                x, _ = next(self.queue_loader)
                keys = self.k_backbone(x.to(self.device))
                self.queue.append(nn.functional.normalize(keys, dim=1))

    def dequeue(self, k):
        '''
        Drop the ``k`` oldest batches from the queue (at least one batch must remain).

        :param k: number of loader batches to remove
        '''
        assert len(self.queue) > k, 'cant dequeue because queue are empty!!!'
        for _ in range(k):
            self.queue.pop(0)

    def empty_queue(self):
        self.queue = None

    def momentum_synchronize_backbone(self, m):
        '''Update the key encoder in place: k <- m * k + (1 - m) * q for every parameter pair.'''
        with torch.no_grad():
            for q_param, k_param in zip(self.q_backbone.parameters(), self.k_backbone.parameters()):
                # Must be in place: rebinding the loop variable would leave the key encoder untouched.
                k_param.mul_(m).add_(q_param, alpha=1 - m)


In [18]:
# Build MoCo (its constructor also creates the CIFAR-10 loaders) and run self-supervised
# pre-training; q_encoder.pth and k_encoder.pth are rewritten after every epoch.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
a = MoCo().to(device)
# NOTE: MoCo.train() with no arguments runs the pre-training loop (default total_epoch=1000),
# not nn.Module's train-mode switch.
a.train()

In [46]:
import torch
import torch.nn as nn
import torchvision
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets
from tqdm import tqdm
from torchvision import transforms


class Classifier(nn.Module):
    '''
    RegNetY-400MF backbone followed by a small LeakyReLU MLP head.

    :param encoder: 'k_encoder', 'q_encoder' or None. With a name, the backbone
        weights are loaded from ``<encoder>.pth``; with None, the backbone keeps
        its random initialisation (the from-scratch baseline).
    :param classes: number of output classes of the head.
    '''
    def __init__(self, encoder, classes=10):
        super(Classifier, self).__init__()
        self.net = torchvision.models.regnet_y_400mf(pretrained=False)
        if encoder is not None:
            load_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
            pretrained_state = torch.load(encoder + '.pth', map_location=load_device)
            self.net.load_state_dict(pretrained_state)
            # The backbone stays trainable (full fine-tuning); freezing it with
            # self.net.requires_grad_(False) here would give a linear-probe setup.
        # The head consumes the backbone's 1000-dimensional output.
        head_layers = [
            nn.LeakyReLU(),
            nn.Linear(1000, 200),
            nn.LeakyReLU(),
            nn.Linear(200, 50),
            nn.LeakyReLU(),
            nn.Linear(50, classes),
        ]
        self.fc = nn.Sequential(*head_layers)

    def forward(self, x):
        '''Return the head's logits for the image batch ``x``.'''
        features = self.net(x)
        logits = self.fc(features)
        return logits


def get_train_loader(batch_size=64):
    '''
    Build CIFAR-10 train and test loaders from ./data (images only converted to tensors).

    :param batch_size: samples per batch for both loaders
    :return: (train_loader, test_loader); the training split is shuffled, the test split is not
    '''
    loaders = []
    for is_train in (True, False):
        split = datasets.CIFAR10('./data', train=is_train, transform=transforms.ToTensor())
        loaders.append(DataLoader(split, batch_size=batch_size, shuffle=is_train))
    train_loader, test_loader = loaders
    return train_loader, test_loader


def _run_epoch(model, loader, criterion, device, optimizer=None):
    '''
    Run one pass over ``loader`` and return (mean loss per sample, accuracy).

    With an ``optimizer``, the model is put in training mode and updated after every
    batch. Without one, it is put in eval mode and run under no_grad. This way
    evaluation neither computes gradients nor updates BatchNorm running statistics.
    '''
    training = optimizer is not None
    model.train(training)
    total_loss, total_correct, total_seen = 0.0, 0, 0
    with torch.set_grad_enabled(training):
        for x, y in tqdm(loader):
            x = x.to(device)
            y = y.to(device)
            if training:
                optimizer.zero_grad()
            logits = model(x)
            loss = criterion(logits, y)
            if training:
                loss.backward()
                optimizer.step()
            n = y.size(0)
            # Weight by the real batch size: the last batch is usually smaller.
            total_loss += loss.item() * n
            total_correct += (logits.argmax(dim=1) == y).sum().item()
            total_seen += n
    if total_seen == 0:
        return 0.0, 0.0
    return total_loss / total_seen, total_correct / total_seen


def train(batch_size=64, lr=1e-3, total_epoch=100, mode=None, save_path='model.pth'):
    '''
    Train a Classifier on CIFAR-10 and track validation metrics per epoch.

    :param batch_size: batch size of the train and test loaders
    :param lr: AdamW learning rate
    :param total_epoch: number of epochs
    :param mode: backbone initialisation passed to Classifier ('q_encoder', 'k_encoder' or None)
    :param save_path: file the weights with the best validation accuracy are saved to
    :return: (per-epoch validation losses, per-epoch validation accuracies,
              best validation accuracy, best validation loss)
    '''
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    train_loader, test_loader = get_train_loader(batch_size)
    model = Classifier(mode).to(device)
    criterion = nn.CrossEntropyLoss().to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    valid_loss_for_draw = []
    valid_acc_for_draw = []
    best_acc = 0
    best_loss = float('inf')

    for epoch in range(total_epoch):
        train_loss, train_acc = _run_epoch(model, train_loader, criterion, device, optimizer)
        valid_loss, valid_acc = _run_epoch(model, test_loader, criterion, device)
        valid_loss_for_draw.append(valid_loss)
        valid_acc_for_draw.append(valid_acc)

        if valid_acc > best_acc:
            best_acc = valid_acc
            torch.save(model.state_dict(), save_path)
        best_loss = min(best_loss, valid_loss)

        print(f'epoch {epoch}, train loss = {train_loss}, train acc = {train_acc}, valid loss = {valid_loss}, valid acc = {valid_acc}')

    return valid_loss_for_draw, valid_acc_for_draw, best_acc, best_loss

In [None]:
# Baseline: randomly initialised backbone. train() returns the per-epoch validation
# loss / accuracy lists plus the best validation accuracy (a) and best validation loss (b).
# NOTE(review): each train() call saves its best checkpoint to model.pth, so the three
# runs in these cells overwrite each other's checkpoint.
none_loss, none_acc,a,b = train(mode = None, total_epoch =50)
print(a,b)

In [None]:
# Backbone initialised from the MoCo query-encoder weights (q_encoder.pth).
q_loss, q_acc, a, b = train(mode = 'q_encoder',total_epoch =50)
print(a,b)

In [None]:
# Backbone initialised from the MoCo momentum/key-encoder weights (k_encoder.pth).
k_loss, k_acc,a,b = train(mode = 'k_encoder', total_epoch =50)
print(a,b)

In [None]:
from matplotlib import pyplot as plt
import numpy as np

# train() returns per-epoch *validation* losses; compare the three backbone initialisations.
fig, ax = plt.subplots()
for curve in (k_loss, q_loss, none_loss):
    ax.plot(range(len(curve)), np.array(curve))
ax.set(title='Validation loss per epoch by backbone initialisation', xlabel='epoch', ylabel='validation loss')
ax.legend(['k_encoder/momentum encoder', 'q_encoder', 'none'])
plt.show()

In [None]:
from matplotlib import pyplot as plt
import numpy as np

# train() returns per-epoch *validation* accuracies; compare the three backbone initialisations.
fig, ax = plt.subplots()
for curve in (k_acc, q_acc, none_acc):
    ax.plot(range(len(curve)), np.array(curve))
ax.set(title='Validation accuracy per epoch by backbone initialisation', xlabel='epoch', ylabel='validation acc')
ax.legend(['k_encoder/momentum encoder', 'q_encoder', 'none'])
plt.show()