In [1]:
!pip install kornia
!pip install torch==1.5.1+cu101  -f https://download.pytorch.org/whl/torch_stable.html
!pip install pytorch-lightning==0.9.0rc2  

Collecting kornia
[?25l  Downloading https://files.pythonhosted.org/packages/c2/60/f0c174c4a2a40b10b04b37c43f5afee3701cc145b48441a2dc5cf9286c3c/kornia-0.3.1-py2.py3-none-any.whl (158kB)
[K     |████████████████████████████████| 163kB 3.5MB/s 
[?25hCollecting torch==1.5.0
[?25l  Downloading https://files.pythonhosted.org/packages/13/70/54e9fb010fe1547bc4774716f11ececb81ae5b306c05f090f4461ee13205/torch-1.5.0-cp36-cp36m-manylinux1_x86_64.whl (752.0MB)
[K     |████████████████████████████████| 752.0MB 24kB/s 
[31mERROR: torchvision 0.6.1+cu101 has requirement torch==1.5.1, but you'll have torch 1.5.0 which is incompatible.[0m
Installing collected packages: torch, kornia
  Found existing installation: torch 1.5.1+cu101
    Uninstalling torch-1.5.1+cu101:
      Successfully uninstalled torch-1.5.1+cu101
Successfully installed kornia-0.3.1 torch-1.5.0
Looking in links: https://download.pytorch.org/whl/torch_stable.html
Collecting torch==1.5.1+cu101
[?25l  Downloading https://download.

In [1]:
from google.colab import drive
drive.mount('/content/drive')

Go to this URL in a browser: https://accounts.google.com/o/oauth2/auth?client_id=947318989803-6bn6qk8qdgf4n4g3pfee6491hc0brc4i.apps.googleusercontent.com&redirect_uri=urn%3aietf%3awg%3aoauth%3a2.0%3aoob&response_type=code&scope=email%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdocs.test%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdrive%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdrive.photos.readonly%20https%3a%2f%2fwww.googleapis.com%2fauth%2fpeopleapi.readonly

Enter your authorization code:
··········
Mounted at /content/drive


In [1]:
import os
import argparse
import multiprocessing
import torch

from pathlib import Path
from PIL import Image

from torchvision import models, transforms
from torch.utils.data import DataLoader, Dataset

import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint

import sys
sys.path.append('/content/drive/My Drive/hateful_memes_data/')
from byol_pytorch import BYOL

from collections import OrderedDict

In [2]:
resnet = models.resnet50(pretrained=True)
BATCH_SIZE = 32
EPOCHS     = 100
LR         = 3e-4
NUM_GPUS   = 1
IMAGE_SIZE = 256
IMAGE_EXTS = ['.png']
NUM_WORKERS = multiprocessing.cpu_count()

# pytorch lightning module

class SelfSupervisedLearner(pl.LightningModule):
    def __init__(self, net, **kwargs):
        super().__init__()
        self.learner = BYOL(net, **kwargs)
        self.i = 0

    def forward(self, images):
        return self.learner(images)

    def training_step(self, images, _):
        loss = self.forward(images)
        return {'loss': loss}

    def training_epoch_end(self, outputs):
        self.i += 1
        torch.save(self.learner.state_dict(), 
        '/content/drive/My Drive/hateful_memes_data/resnet50_byol_weights_{}.pkl'.format(self.i))
        train_loss_mean = 0
        for output in outputs:
          train_loss = output['loss']
          train_loss_mean += train_loss
        train_loss_mean /= len(outputs)
        tqdm_dict = {'train_loss': train_loss_mean}
        return OrderedDict({'loss': train_loss_mean, 'progress_bar': tqdm_dict, 'log': tqdm_dict})

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=LR)

    def on_before_zero_grad(self, _):
        self.learner.update_moving_average()

# images dataset

def expand_greyscale(t):
    return t.expand(3, -1, -1)

class ImagesDataset(Dataset):
    def __init__(self, folder, image_size):
        super().__init__()
        self.folder = folder
        self.paths = []
        print(folder)

        for path in Path(f'{folder}').glob('**/*'):
            _, ext = os.path.splitext(path)
            if ext.lower() in IMAGE_EXTS:
                self.paths.append(path)

        print(f'{len(self.paths)} images found')

        self.transform = transforms.Compose([
            transforms.Resize(image_size),
            transforms.CenterCrop(image_size),
            transforms.ToTensor(),
            transforms.Lambda(expand_greyscale)
        ])

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, index):
        path = self.paths[index]
        img = Image.open(path)
        img = img.convert('RGB')
        return self.transform(img)

In [None]:
train_ds = ImagesDataset('/content/drive/My Drive/hateful_memes_data/train/data', IMAGE_SIZE)
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS)

model = SelfSupervisedLearner(
    resnet,
    image_size = IMAGE_SIZE,
    hidden_layer = 'avgpool',
    projection_size = 256,
    projection_hidden_size = 4096,
    moving_average_decay = 0.99
)

#checkpoint_callback = ModelCheckpoint(filepath='/content/drive/My Drive/hateful_memes_data/resnet18_model_checkpoint_{epoch}.ckpt',
#                                     monitor='loss', save_top_k=3)

trainer = pl.Trainer(gpus=NUM_GPUS, max_epochs=EPOCHS) #, checkpoint_callback=checkpoint_callback)
trainer.fit(model, train_loader) 

/content/drive/My Drive/hateful_memes_data/train/data
8500 images found


GPU available: True, used: True
TPU available: False, using: 0 TPU cores
CUDA_VISIBLE_DEVICES: [0]

  | Name    | Type | Params
---------------------------------
0 | learner | BYOL | 72 M  


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

In [None]:
checkpoint_path = '/content/drive/My Drive/hateful_memes_data/resnet18_model_checkpoint_epoch_24.ckpt'

In [None]:
loaded_model = SelfSupervisedLearner.load_from_checkpoint(checkpoint_path=checkpoint_path)

TypeError: ignored