참고 : https://github.com/smartdanny/MoCoV2_CIFAR10

## 1. Data

In [1]:
import torch
import torch.nn as nn
import torchvision
import torchmetrics  # pytorch_lightning 버전이 업데이트됨에 따라 metrics 메서드가 삭제됨
import pytorch_lightning as pl
import lightly

# moco_model.py : https://github.com/smartdanny/MoCoV2_CIFAR10
# moco_model.py에서 torchmetrics를 import한 후, Classifier 부분의 self.accuracy=pl.metrics.Accuracy()를 torchmetrics.Accuracy()로 변경
import moco_model  

from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning import loggers as pl_loggers

SEED = 1

In [2]:
# Report the installed versions of the libraries whose APIs this notebook depends on.
print(*(module.__version__ for module in (lightly, pl, torchmetrics)), sep='\n')

1.2.7
1.5.10
0.7.2


In [3]:
# DATA hyperparams
# Worker processes used by every DataLoader.
num_workers = 1
# The MoCo loader and both classifier loaders share the same batch size.
moco_batch_size = classifier_train_batch_size = classifier_test_batch_size = 512

In [4]:
# The dataset structure should be like this (one sub-folder per class):
# cifar10/train/
#  L airplane/
#    L 10008_airplane.png
#    L ...
#  L automobile/
#  L bird/
#  L cat/
#  L deer/
#  L dog/
#  L frog/
#  L horse/
#  L ship/
#  L truck/
# cifar10/test/ presumably uses the same class sub-folders, since the validation
# step compares predictions against the labels yielded by the test dataset.
import os

# Dataset root. Set the CIFAR10_ROOT environment variable to point elsewhere
# without editing the notebook; the default is the original author's location.
_cifar10_root = os.environ.get(
    'CIFAR10_ROOT', 'C:/Users/Moon/Desktop/cifar10/cifar10'
).rstrip('/\\')  # tolerate a trailing slash in the override
path_to_train = f'{_cifar10_root}/train'
path_to_test = f'{_cifar10_root}/test'

### 1.1 Augmentations

In [5]:
################### Classifier Augmentations ###################
# Both classifier pipelines finish with the same normalization using lightly's
# ImageNet mean/std, so one Normalize transform is built and reused.
_imagenet_normalize = torchvision.transforms.Normalize(
    mean=lightly.data.collate.imagenet_normalize['mean'],
    std=lightly.data.collate.imagenet_normalize['std'],
)

# Usual CIFAR-10 training augmentations: padded random crop + horizontal flip.
train_classifier_transforms = torchvision.transforms.Compose([
    torchvision.transforms.RandomCrop(32, padding=4),
    torchvision.transforms.RandomHorizontalFlip(),
    torchvision.transforms.ToTensor(),
    _imagenet_normalize,
])

# Evaluation pipeline: no random augmentation, just resize/convert/normalize.
test_transforms = torchvision.transforms.Compose([
    torchvision.transforms.Resize((32, 32)),
    torchvision.transforms.ToTensor(),
    _imagenet_normalize,
])

################### MOCO Augmentations ###################
# MoCo v2 reuses the SimCLR augmentation pipeline; Gaussian blur is switched off.
collate_fn = lightly.data.SimCLRCollateFunction(input_size=32, gaussian_blur=0.)

### 1.2 Datasets

In [6]:
################### Classifier Datasets ###################
# The downstream classifier is trained with cross-entropy on the dataset labels,
# so it gets the light CIFAR-10 augmentations (train_classifier_transforms) rather
# than the strong MoCo augmentations, which usually reduce the accuracy of models
# that are not trained contrastively.
dataset_train_classifier = lightly.data.LightlyDataset(input_dir=path_to_train,
                                                       transform=train_classifier_transforms)
# Evaluation images only receive the deterministic test_transforms.
dataset_test = lightly.data.LightlyDataset(input_dir=path_to_test,
                                           transform=test_transforms)

################### MOCO Dataset ###################
# No per-sample transform here: the SimCLR collate_fn passed to the MoCo
# dataloader produces the two augmented views of each image per batch.
dataset_train_moco = lightly.data.LightlyDataset(input_dir=path_to_train)

### 1.3 Dataloader

In [7]:
################### Classifier Dataloaders ###################
# Classifier training batches: shuffled, incomplete final batch discarded.
dataloader_train_classifier = torch.utils.data.DataLoader(
    dataset_train_classifier, batch_size=classifier_train_batch_size,
    shuffle=True, drop_last=True, num_workers=num_workers,
)

# Evaluation batches: fixed order, every sample kept (including a partial last batch).
dataloader_test = torch.utils.data.DataLoader(
    dataset_test, batch_size=classifier_test_batch_size,
    shuffle=False, drop_last=False, num_workers=num_workers,
)

################### MOCO Dataloader ###################
# collate_fn assembles each batch into two augmented views for contrastive training.
dataloader_train_moco = torch.utils.data.DataLoader(
    dataset_train_moco, batch_size=moco_batch_size,
    shuffle=True, drop_last=True, collate_fn=collate_fn, num_workers=num_workers,
)

## 2. Model

In [8]:
import pytorch_lightning as pl
import lightly  # 참고 : https://docs.lightly.ai/lightly.models.html
import torch.nn as nn
import torch
import numpy as np
import copy

In [9]:
class Classifier(pl.LightningModule):
    """Downstream classifier trained on top of a frozen, pre-trained MoCo encoder.

    Args:
        model: MoCo wrapper whose ``backbone`` produces the features (e.g.
            ``MocoModel.resnet_moco``). Its parameters are frozen and it is kept
            in eval mode, so only ``self.fc`` is trained.
        max_epochs: length (T_max) of the cosine learning-rate schedule.

    After every validation epoch, the mean training loss of that epoch is appended
    to ``self.train_losses`` and that epoch's validation accuracy to
    ``self.val_accs``. Lightning's sanity-check validation also produces one entry
    (with a NaN train loss, because no training step has run yet).
    """

    def __init__(self, model, max_epochs):
        super().__init__()
        self.resnet_moco = model
        self.max_epochs = max_epochs
        # Buffers for the current epoch; cleared in validation_epoch_end.
        self.epoch_train_losses = []  # per-step training losses
        self.epoch_val_accs = []  # per-batch validation accuracies
        # One entry per validation epoch (sanity check included).
        self.train_losses = []
        self.val_accs = []

        # Freeze the encoder: only self.fc is handed to the optimizer.
        for p in self.resnet_moco.parameters():
            p.requires_grad = False

        # Small MLP head mapping the 512-d backbone features to 10 classes.
        self.fc = nn.Sequential(
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 10),  # number of classes = 10
        )

        self.accuracy = torchmetrics.Accuracy()

    def train(self, mode=True):
        """Switch train/eval mode while keeping the frozen encoder in eval mode.

        ``requires_grad = False`` alone does not stop BatchNorm layers from
        normalizing with batch statistics and updating their running statistics
        in train mode, which Lightning enables on the whole module during fit.
        """
        super().train(mode)
        self.resnet_moco.eval()
        return self

    def forward(self, x):
        """Return class logits of shape (batch, 10) for a batch of images."""
        with torch.no_grad():
            # Backbone output is (batch, 512, 1, 1). flatten (unlike squeeze) keeps
            # the batch dimension even when the batch holds a single image.
            features = self.resnet_moco.backbone(x).flatten(start_dim=1)
            features = nn.functional.normalize(features, dim=1)
        return self.fc(features)

    def custom_histogram_weights(self):
        """Log a histogram of every parameter to the TensorBoard logger."""
        for name, params in self.named_parameters():
            self.logger.experiment.add_histogram(name, params, self.current_epoch)

    def training_step(self, batch, batch_idx):
        x, y, _ = batch  # third item (presumably the file name) is unused
        loss = nn.functional.cross_entropy(self.forward(x), y)
        self.log('train_loss_fc', loss)
        self.epoch_train_losses.append(loss.detach().cpu())
        return loss

    # TODO: logging histograms does not work when self.fc is an nn.Sequential,
    # so custom_histogram_weights() is not called from training_epoch_end.

    def validation_step(self, batch, batch_idx):
        x, y, _ = batch
        probs = torch.nn.functional.softmax(self.forward(x), dim=1)
        # Calling the metric updates its running state and returns this batch's
        # accuracy; the epoch value is computed once in validation_epoch_end.
        batch_acc = self.accuracy(probs, y)
        self.epoch_val_accs.append(batch_acc.detach().cpu())

    def validation_epoch_end(self, outputs):
        """Record this epoch's mean train loss and val accuracy, then reset buffers."""
        if self.epoch_train_losses:
            epoch_loss = torch.stack(self.epoch_train_losses).mean().item()
        else:
            epoch_loss = float('nan')  # sanity-check run: no training steps yet
        epoch_acc = self.accuracy.compute()
        self.log('val_acc', epoch_acc, prog_bar=True)
        self.train_losses.append(epoch_loss)
        self.val_accs.append(epoch_acc.item())
        # Without these resets the metric state and the buffers would keep
        # accumulating across epochs (and would include the sanity-check batches).
        self.accuracy.reset()
        self.epoch_train_losses.clear()
        self.epoch_val_accs.clear()

    def configure_optimizers(self):
        # Only the head is optimized. lr=3 was found empirically to work well for
        # this MLP head (the single-linear-layer setup it was adapted from used lr=30).
        optim = torch.optim.SGD(self.fc.parameters(), lr=3.)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optim, self.max_epochs)
        return [optim], [scheduler]

In [10]:
class MocoModel(pl.LightningModule):
    """MoCo self-supervised pre-training of a ResNet-18 encoder.

    Optionally, every ``downstream_test_every`` epochs a deep copy of the encoder
    is evaluated by training a ``Classifier`` head on it for
    ``downstream_max_epochs`` epochs; the best downstream train loss and
    validation accuracy are logged.

    Args:
        memory_bank_size: size of the NTXentLoss memory bank of negatives.
        moco_max_epochs: T_max of the cosine learning-rate schedule.
        downstream_max_epochs: epochs for each downstream classifier run.
        dataloader_train_classifier: train loader for the downstream run.
        dataloader_test: validation loader for the downstream run.
        downstream_test_every: run the downstream evaluation every N epochs;
            0 (the default) disables it.
    """

    def __init__(self, memory_bank_size, moco_max_epochs, downstream_max_epochs=0, dataloader_train_classifier=None, dataloader_test=None, downstream_test_every=0):
        super().__init__()

        self.moco_max_epochs = moco_max_epochs
        self.downstream_max_epochs = downstream_max_epochs
        self.dataloader_train_classifier = dataloader_train_classifier
        self.dataloader_test = dataloader_test
        self.downstream_test_every = downstream_test_every

        # ResNet-18 backbone (feature extractor) without its classification head.
        # num_splits configures lightly's SplitBatchNorm (simulates multi-GPU batch norm).
        resnet = lightly.models.ResNetGenerator('resnet-18', 1, num_splits=8)
        backbone = nn.Sequential(
            *list(resnet.children())[:-1],  # drop the final classification linear layer
            nn.AdaptiveAvgPool2d(1),  # pooled backbone features have 512 channels
        )

        # MoCo wrapper: adds the projection head on top of the backbone; m=0.99 is
        # the momentum of the key-encoder update.
        # NOTE(review): projected output is 128-d per the original note — confirm.
        self.resnet_moco = lightly.models.MoCo(backbone, num_ftrs=512, m=0.99, batch_shuffle=True)

        # Contrastive cross-entropy loss with the optional memory bank.
        self.criterion = lightly.loss.NTXentLoss(
            temperature=0.1,
            memory_bank_size=memory_bank_size)

    def forward(self, x):
        """Return the MoCo wrapper's output for ``x`` (previously discarded)."""
        return self.resnet_moco(x)

    def custom_histogram_weights(self):
        """Log a histogram of every parameter to TensorBoard (no-op without a logger)."""
        if self.logger is None:
            return
        for name, params in self.named_parameters():
            self.logger.experiment.add_histogram(name, params, self.current_epoch)

    def test_downstream_training(self):
        """Train a Classifier on a frozen copy of the encoder and log its best scores."""
        # Deep copy so the downstream run cannot alter the encoder being pre-trained.
        classifier = Classifier(copy.deepcopy(self.resnet_moco), max_epochs=self.downstream_max_epochs)
        trainer = pl.Trainer(
            max_epochs=self.downstream_max_epochs,
            gpus=1 if torch.cuda.is_available() else 0,  # do not require a GPU
            logger=None,
        )
        trainer.fit(
            classifier,
            self.dataloader_train_classifier,
            self.dataloader_test,
        )
        # Drop the first entry, produced by Lightning's validation sanity check.
        train_losses = classifier.train_losses[1:]
        val_accs = classifier.val_accs[1:]
        print(train_losses)
        print(val_accs)
        print('-----')
        if not train_losses or not val_accs:
            return  # e.g. downstream_max_epochs == 0: nothing to summarize
        self.log('DOWNSTREAM_min_train_loss', np.min(train_losses))
        self.log('DOWNSTREAM_max_val_acc', np.max(val_accs))

    def training_step(self, batch, batch_idx):
        (x0, x1), _, _ = batch  # two augmented views produced by the collate_fn
        y0, y1 = self.resnet_moco(x0, x1)
        loss = self.criterion(y0, y1)
        self.log('train_loss_ssl', loss)
        return loss

    def training_epoch_end(self, outputs):
        self.custom_histogram_weights()
        # downstream_test_every <= 0 (the default) disables the periodic evaluation
        # instead of raising ZeroDivisionError in the modulo below.
        run_downstream = (
            self.downstream_test_every > 0
            and self.dataloader_train_classifier is not None
            and self.current_epoch % self.downstream_test_every == 0
        )
        if run_downstream:
            print('... training downstream classifier...')
            self.test_downstream_training()

    def configure_optimizers(self):
        optim = torch.optim.SGD(self.resnet_moco.parameters(), lr=6e-2,
                                momentum=0.9, weight_decay=5e-4)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optim, self.moco_max_epochs)
        return [optim], [scheduler]

## 3. Training

In [11]:
# MODEL hyperparams
# NTXentLoss memory-bank size and MoCo pre-training epochs (cosine schedule length).
memory_bank_size, moco_max_epochs = 4096, 3000
# Downstream evaluation: epochs per classifier run, and how often (in MoCo epochs) to run it.
downstream_max_epochs, downstream_test_every = 60, 50

In [12]:
# MoCo model that periodically evaluates a downstream classifier on its encoder.
model = MocoModel(
    memory_bank_size=memory_bank_size,
    moco_max_epochs=moco_max_epochs,
    downstream_max_epochs=downstream_max_epochs,
    dataloader_train_classifier=dataloader_train_classifier,
    dataloader_test=dataloader_test,
    downstream_test_every=downstream_test_every,
)



In [13]:
# Keep the 5 checkpoints with the lowest MoCo training loss.
checkpoint_callback = ModelCheckpoint(
    monitor='train_loss_ssl',
    mode='min',
    save_top_k=5,
    dirpath='./saved_models/resnet_moco',
    filename='{epoch}-{train_loss_ssl:.2f}',
    verbose=True,
)

In [14]:
# use a GPU if available: `gpus` is the GPU count handed to the Trainer (1 or 0)
gpus = int(torch.cuda.is_available())
print(f'Using gpu: {bool(gpus)}')
if not gpus:
    print('--- NOT USING GPUS THIS TAKE LONG TIME ---')

# set up tensorboard logger, one run directory per pre-training length
tb_logger = pl_loggers.TensorBoardLogger(
    save_dir='./lightning_logs/',
    name=f'TESTmoco_{moco_max_epochs}eps',
)

Using gpu: True


In [None]:
# Self-supervised MoCo pre-training with checkpointing and TensorBoard logging.
trainer = pl.Trainer(
    max_epochs=moco_max_epochs,
    gpus=gpus,
    callbacks=[checkpoint_callback],
    logger=tb_logger,
)
trainer.fit(model, dataloader_train_moco)

In [None]:
# Deliberate stop so "Run All" does not continue into the evaluation cells below.
# A bare `break` outside a loop is a SyntaxError, which also prevents this file
# from compiling at all when exported as a script; raise an explicit stop instead.
raise SystemExit('Stopped after MoCo training; run the evaluation cells below manually.')

In [15]:
class testMocoModel(pl.LightningModule):
    """Re-creation of MocoModel's modules for reloading its checkpoints and probing the loss.

    ``resnet_moco`` and ``criterion`` are built exactly like in MocoModel, so the
    state_dict keys of a MocoModel checkpoint match this module.
    NOTE(review): the memory bank size is read from the notebook-global
    ``memory_bank_size`` and must equal the value used when the checkpoint was saved.
    """

    def __init__(self):
        super().__init__()

        # ResNet-18 backbone without its classification head
        resnet = lightly.models.ResNetGenerator('resnet-18', 1, num_splits=8)
        backbone = nn.Sequential(
            *list(resnet.children())[:-1],
            nn.AdaptiveAvgPool2d(1),
        )

        # MoCo wrapper around the backbone
        self.resnet_moco = \
            lightly.models.MoCo(backbone, num_ftrs=512, m=0.99, batch_shuffle=True)

        # contrastive loss with the optional memory bank
        self.criterion = lightly.loss.NTXentLoss(
            temperature=0.1,
            memory_bank_size=memory_bank_size)

    def forward(self, x):
        """Return the MoCo wrapper's output for ``x`` (previously discarded)."""
        return self.resnet_moco(x)

    def contrastive_loss(self, x0, x1):
        """Return (grad w.r.t. x0, grad w.r.t. x1, loss) for two transformed views.

        Parameter gradients are zeroed first. ``requires_grad`` is set on x0 and x1
        in place, so they must be floating-point leaf tensors.
        """
        self.zero_grad()
        x0.requires_grad = True
        x1.requires_grad = True
        y0, y1 = self.resnet_moco(x0, x1)
        loss = self.criterion(y0, y1)
        loss.backward()
        return x0.grad, x1.grad, loss

    def contrastive_loss_nograd(self, x0, x1):
        """Return the contrastive loss for two views without building an autograd graph."""
        with torch.no_grad():
            y0, y1 = self.resnet_moco(x0, x1)
            loss = self.criterion(y0, y1)
        return loss

    def custom_histogram_weights(self):
        """Log a histogram of every parameter to TensorBoard (no-op without a logger)."""
        if self.logger is None:
            return
        for name, params in self.named_parameters():
            self.logger.experiment.add_histogram(name, params, self.current_epoch)

    def training_step(self, batch, batch_idx):
        (x0, x1), _, _ = batch
        y0, y1 = self.resnet_moco(x0, x1)
        loss = self.criterion(y0, y1)
        self.log('train_loss_ssl', loss)
        return loss

    def training_epoch_end(self, outputs):
        self.custom_histogram_weights()

    def configure_optimizers(self):
        optim = torch.optim.SGD(self.resnet_moco.parameters(), lr=6e-2,
                                momentum=0.9, weight_decay=5e-4)
        # The original used a global `max_epochs` that this notebook never defines
        # (NameError once training starts); use the attached Trainer's epoch budget.
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optim, self.trainer.max_epochs)
        return [optim], [scheduler]

In [16]:
# load_from_checkpoint is a classmethod that *returns* a new, restored model.
# Calling it on an instance and discarding the result left `mocomodel` with its
# random initialization, so the restored model must be assigned.
mocomodel = testMocoModel.load_from_checkpoint('./saved_models/resnet_moco/epoch=311-train_loss_ssl=2.40.ckpt')
mocomodel.eval();

In [17]:
# Downstream classifier from the external moco_model module, built on the restored
# MoCo wrapper. NOTE(review): resnet_moco is passed without copying, so if
# moco_model.Classifier freezes parameters (as the notebook's Classifier does),
# mocomodel.resnet_moco itself gets frozen too — confirm in moco_model.py.
clf = moco_model.Classifier(mocomodel.resnet_moco, max_epochs=25)

In [18]:
# Notebook-only inspection: a bare expression displays the class object (showing
# which Classifier is in use); it has no effect when run as a script.
moco_model.Classifier

moco_model.Classifier

In [20]:
# Train the classifier head on the frozen encoder, validating on the test split.
# max_epochs=25 presumably should match the max_epochs given to clf (its cosine
# schedule length). Fall back to CPU instead of hard-coding gpus=1, which fails
# on machines without CUDA.
trainer = pl.Trainer(max_epochs=25, gpus=1 if torch.cuda.is_available() else 0)
trainer.fit(
    clf,
    dataloader_train_classifier,
    dataloader_test,
)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name        | Type       | Params
-------------------------------------------
0 | resnet_moco | MoCo       | 23.0 M
1 | fc          | Sequential | 267 K 
2 | accuracy    | Accuracy   | 0     
-------------------------------------------
267 K     Trainable params
23.0 M    Non-trainable params
23.3 M    Total params
93.048    Total estimated model params size (MB)


Validation sanity check: 0it [00:00, ?it/s]

  f"The dataloader, {name}, does not have many workers which may be a bottleneck."
  "Trying to infer the `batch_size` from an ambiguous collection. The batch size we"
  out=out, **kwargs)
  ret = ret.dtype.type(ret / rcount)
  f"The dataloader, {name}, does not have many workers which may be a bottleneck."


Training: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

  "Trying to infer the `batch_size` from an ambiguous collection. The batch size we"


Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

In [23]:
# Notebook-only inspection: display the classifier's module tree.
clf

Classifier(
  (resnet_moco): MoCo(
    (backbone): Sequential(
      (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (1): SplitBatchNorm(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): Sequential(
        (0): BasicBlock(
          (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn1): SplitBatchNorm(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): SplitBatchNorm(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (shortcut): Sequential()
        )
        (1): BasicBlock(
          (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn1): SplitBatchNorm(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv2): Conv2d(64, 64, kernel_size=