In [2]:
! pip install numpy
! pip install --pre -U torch torchvision -f https://download.pytorch.org/whl/nightly/cu102/torch_nightly.html

Looking in links: https://download.pytorch.org/whl/nightly/cu102/torch_nightly.html
Collecting torch
  Using cached https://download.pytorch.org/whl/nightly/cu102/torch-1.7.0.dev20200702-cp36-cp36m-linux_x86_64.whl (893.2 MB)
Collecting torchvision
  Using cached https://download.pytorch.org/whl/nightly/cu102/torchvision-0.8.0.dev20200701-cp36-cp36m-linux_x86_64.whl (5.9 MB)
[31mERROR: torchvision 0.8.0.dev20200701 has requirement torch==1.7.0.dev20200701, but you'll have torch 1.7.0.dev20200702 which is incompatible.[0m
Installing collected packages: torch, torchvision
  Attempting uninstall: torch
    Found existing installation: torch 1.3.1
    Uninstalling torch-1.3.1:
      Successfully uninstalled torch-1.3.1
  Attempting uninstall: torchvision
    Found existing installation: torchvision 0.4.2
    Uninstalling torchvision-0.4.2:
      Successfully uninstalled torchvision-0.4.2
Successfully installed torch-1.7.0.dev20200702 torchvision-0.8.0.dev20200701


In [3]:
! pip install pytorch-lightning

Collecting pytorch-lightning
  Downloading pytorch_lightning-0.8.4-py3-none-any.whl (304 kB)
[K     |████████████████████████████████| 304 kB 4.5 MB/s eta 0:00:01
[?25hCollecting PyYAML>=5.1
  Downloading PyYAML-5.3.1.tar.gz (269 kB)
[K     |████████████████████████████████| 269 kB 13.9 MB/s eta 0:00:01
[?25hCollecting tqdm>=4.41.0
  Downloading tqdm-4.47.0-py2.py3-none-any.whl (66 kB)
[K     |████████████████████████████████| 66 kB 12.0 MB/s eta 0:00:01
Building wheels for collected packages: PyYAML
  Building wheel for PyYAML (setup.py) ... [?25ldone
[?25h  Created wheel for PyYAML: filename=PyYAML-5.3.1-cp36-cp36m-linux_x86_64.whl size=45919 sha256=6420b08d2a2f9d7f97a06d8a95662cd3c32468e455f2c683d62f723e12f5f9ef
  Stored in directory: /root/.cache/pip/wheels/e5/9d/ad/2ee53cf262cba1ffd8afe1487eef788ea3f260b7e6232a80fc
Successfully built PyYAML
Installing collected packages: PyYAML, tqdm, pytorch-lightning
Successfully installed PyYAML-5.3.1 pytorch-lightning-0.8.4 tqdm-4.47.0


In [3]:
import numpy as np
import torch
from torch.utils.data import Dataset
from torchvision import datasets
from torchvision import transforms

from pytorch_lightning import Trainer

class SimCLRDataset(Dataset):
    def __init__(self, dataset):
        """Initialize a wrapper of a generic image classification dataset for SimCLR training.

        Args:
            dataset (torch.utils.data.Dataset): an image PyTorch dataset - when iterating over it
                it should return something of the form (image) or (image, label).
        """
        self.dataset = dataset

    def __getitem__(self, index):
        dataset_item = self.dataset[index]
        if type(dataset_item) is tuple:
            image = dataset_item[0]
        else:
            image = dataset_item
        return image, image

    def __len__(self):
        return len(self.dataset)

    @staticmethod
    def mixup(x, alpha=0.4):
        batch_size = x.size()[0] // 2
        if alpha > 0:
            lam = np.random.beta(alpha, alpha, batch_size)
            lam = np.concatenate(
                [lam[:, None], 1 - lam[:, None]], 1
            ).max(1)[:, None, None, None]
            lam = torch.from_numpy(lam).float()
            if torch.cuda.is_available():
                lam = lam.cuda()
        else:
            lam = 1.
        # This is SimCLR specific - we want to use the same mixing for the augmented pairs
        lam = torch.cat([lam, lam])
        index = torch.randperm(batch_size)
        # This is SimCLR specific - we want to use the same permutation on the augmented pairs
        index = torch.cat([index, batch_size + index])
        if torch.cuda.is_available():
            index = index.cuda()
        mixed_x = lam * x + (1 - lam) * x[index, :]

        return mixed_x, lam


def imagenet_normalize_transform():
    return transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])


def get_train_transforms(size=224, color_jitter_prob=0.8, grayscale_prob=0.2):
    color_jitter = transforms.ColorJitter(0.8, 0.8, 0.8, 0.2)
    return transforms.Compose([
        transforms.RandomResizedCrop(size=(size, size)),
        transforms.RandomHorizontalFlip(),
        transforms.RandomApply([color_jitter], p=color_jitter_prob),
        transforms.RandomGrayscale(p=grayscale_prob),
        transforms.ToTensor(),
        imagenet_normalize_transform()
    ])


def get_val_transforms(size=224):
    return transforms.Compose([
        transforms.Resize(size=(size, size)),
        transforms.ToTensor(),
        imagenet_normalize_transform()
    ])


In [4]:
import torch.nn as nn


class NTXEntCriterion(nn.Module):
    """Normalized, temperature-scaled cross-entropy criterion, as suggested in the SimCLR paper.

    Parameters:
        temperature (float, optional): temperature to scale the confidences. Defaults to 0.5.
    """
    criterion = nn.CrossEntropyLoss(reduction="sum")
    similarity = nn.CosineSimilarity(dim=2)

    def __init__(self, temperature=0.5):
        super(NTXEntCriterion, self).__init__()
        self.temperature = temperature
        self.batch_size = None
        self.mask = None

    def mask_correlated_samples(self, batch_size):
        """Masks examples in a batch and it's augmented pair for computing the valid summands for
            the criterion.

        Args:
            batch_size (int): batch size of the individual batch (not including it's augmented pair)

        Returns:
            torch.Tensor: a mask (tensor of 0s and 1s), where 1s indicates a pair of examples in a
                batch that will contribute to the overall batch loss
        """
        mask = torch.ones((batch_size * 2, batch_size * 2), dtype=bool)
        mask = mask.fill_diagonal_(0)
        for i in range(batch_size):
            mask[i, batch_size + i] = 0
            mask[batch_size + i, i] = 0
        return mask

    def compute_similarities(self, z_i, z_j, temperature):
        """Computes the similarities between two projections `z_i` and `z_j`, scaling based on
            `temperature`.

        Args:
            z_i (torch.Tensor): projection of a batch
            z_j (torch.Tensor): projection of the augmented pair for the batch
            temperature (float): temperature to scale the similarity by

        Returns:
            torch.Tensor: tensor of similarities for the positive and negative pairs
        """
        batch_size = len(z_i)
        mask = self.mask_correlated_samples(batch_size)

        p1 = torch.cat((z_i, z_j), dim=0)
        sim = self.similarity(p1.unsqueeze(1), p1.unsqueeze(0)) / temperature

        sim_i_j = torch.diag(sim, batch_size)
        sim_j_i = torch.diag(sim, -batch_size)

        positive_samples = torch.cat((sim_i_j, sim_j_i), dim=0).reshape(
            batch_size * 2, 1
        )
        negative_samples = sim[mask].reshape(batch_size * 2, -1)

        logits = torch.cat((positive_samples, negative_samples), dim=1)
        return logits

    def forward(self, z):
        """Computes the loss for a batch and its augmented pair.

        Args:
            z (torch.Tensor): tensor of a batch and it's augmented pair, concatenated

        Returns:
            torch.Tensor: loss for the given batch
        """
        double_batch_size = len(z)
        batch_size = double_batch_size // 2
        z_i, z_j = z[:double_batch_size // 2], z[double_batch_size // 2:]
        if self.batch_size is None or batch_size != self.batch_size:
            self.batch_size = batch_size
            self.mask = None

        if self.mask is None:
            self.mask = self.mask_correlated_samples(self.batch_size)

        logits = self.compute_similarities(z_i, z_j, self.temperature)
        labels = torch.zeros(self.batch_size * 2).long()
        logits, labels = logits.to(z.device), labels.to(z.device)
        loss = self.criterion(logits, labels)
        loss /= 2 * self.batch_size
        return loss

In [5]:
from pytorch_lightning.core.lightning import LightningModule
from torch.utils.data import DataLoader
import torchvision

class SimCLRModel(LightningModule):
    """SimCLR training network for a generic torchvision model (restricted to `allowed_models`). """

    allowed_models = ['resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152']
    allowed_datasets = ['CIFAR10', 'CIFAR100', 'STL10', 'SVHN']

    def __init__(
        self, model_name='resnet18', pretrained=True, projection_dim=64, temperature=0.5,
        batch_size=128, image_size=224, save_hparams=True
    ):
        super().__init__()
        layers = list(getattr(torchvision.models, model_name)(pretrained=pretrained).children())
        self.model = nn.Sequential(*layers[:-1])
        self.projection_head = nn.Linear(layers[-1].in_features, projection_dim)
        self.loss = NTXEntCriterion(temperature=temperature)
        self.batch_size = batch_size
        self.image_size = image_size
        self.prepare_data()

    def forward(self, x):
#         print("Forwarding")
        out = self.model(x)
        out = out.view(x.size(0), -1)
        out = self.projection_head(out)
        return out

    def training_step(self, batch, batch_idx):
#         print("Training step")
        projections = self(batch)
        loss = self.loss(projections)
        tensorboard_logs = {'train_loss': loss}
        self.logger.scalar('loss', loss)
        return {'loss': loss, 'log': tensorboard_logs}

    def training_epoch_end(self, outputs):
#         print("Finished Epoch")
        loss_mean = torch.stack([x['loss'] for x in outputs]).mean()
        return {'train_loss': loss_mean}

#     def validation_step(self, batch, batch_idx):
# #         print("Validation step")
#         projections = self(batch)
#         loss = self.loss(projections)
#         tensorboard_logs = {'val_loss': loss}
#         return {'loss': loss, 'log': tensorboard_logs}

#     def validation_epoch_end(self, outputs):
# #         print("Finished Epoch")
#         val_loss_mean = torch.stack([x['loss'] for x in outputs]).mean()
#         return {'val_loss': val_loss_mean}

    def configure_optimizers(self):
        return torch.optim.Adam([
            {'params': self.model.parameters(), 'lr': 0.00001},
            {'params': self.projection_head.parameters(), 'lr': 0.001}
        ])

    def prepare_data(self):
#         print("Getting Data")
        train_transforms, val_transforms = (
            get_train_transforms(size=self.image_size),
            get_val_transforms(size=self.image_size)
        )
        train_dataset = torchvision.datasets.ImageFolder(
            '/tf/data/combined', transform = train_transforms)
        self.train_dataset = SimCLRDataset(train_dataset)
#         val_dataset = torchvision.datasets.ImageFolder(
#             '/tf/data/combined', transform = val_transforms)
#         self.val_dataset = SimCLRDataset(val_dataset)
#         print("Finished getting data")

    def collate_fn(self, batch):
#         print("Collating data")
        return torch.cat([torch.stack([b[0] for b in batch]), torch.stack([b[1] for b in batch])])

    def train_dataloader(self):
#         print("Grabbing dataloader")
        return DataLoader(
            self.train_dataset, batch_size=self.batch_size, num_workers=64, shuffle=True,
            collate_fn=self.collate_fn
        )

#     def val_dataloader(self):
# #         print("Grabbing dataloader")
#         return DataLoader(
#             self.val_dataset, batch_size=self.batch_size, num_workers=64, shuffle=False,
#             collate_fn=self.collate_fn
#         )

In [6]:
model = SimCLRModel(
    model_name = 'resnet50',
    pretrained = True,
    batch_size = 1792,
    image_size = 224
)

In [7]:
model = SimCLRModel.load_from_checkpoint(checkpoint_path='/tf/data/models/simclr/checkpointepoch=98.ckpt')

In [10]:
torch.save(model.state_dict(), '/tf/data/models/simclr/simclr-epoch98.pth')

In [2]:
from pytorch_lightning.callbacks import ModelCheckpoint


checkpoint_callback = ModelCheckpoint(
    filepath = '/tf/data/models/simclr/', prefix = "checkpoint",
    monitor = "val_loss", mode = "min", save_top_k = 3
)

train_params = dict(
#     accumulate_grad_batches = 1, # hparams.gradient_accumulation_steps,
    gpus = 3,
    max_epochs = 5, # hparams.num_train_epochs,
    early_stop_callback = False,
#     gradient_clip_val = 3, # hparams.max_grad_norm,
#     checkpoint_callback = checkpoint_callback,
    num_nodes = 1,
    prepare_data_per_node = False,
    distributed_backend = 'ddp',
#     precision = 16
#     num_workers = 0
#     callbacks=[LoggingCallback()],
)

trainer = Trainer(**train_params)



NameError: name 'Trainer' is not defined

In [6]:
%load_ext tensorboard
%tensorboard --bind_all --logdir lightning_logs/

In [None]:
# trainer = Trainer.from_argparse_args(args)
trainer.fit(model)