In [None]:
pip install kornia
pip install torch==1.5.1+cu101  -f https://download.pytorch.org/whl/torch_stable.html
pip install pytorch_lightning    

In [8]:
import argparse
import multiprocessing
import os
from pathlib import Path

import pytorch_lightning as pl
import torch
from byol_pytorch import BYOL
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchvision import models, transforms

In [None]:
# --- run configuration ------------------------------------------------------
NUM_WORKERS = multiprocessing.cpu_count()  # DataLoader workers: one per reported CPU
IMAGE_EXTS = ['.png']                      # lower-cased file suffixes accepted by the dataset
IMAGE_SIZE = 256                           # resize / centre-crop size, also given to BYOL
NUM_GPUS = 1                               # passed to the Lightning Trainer
LR = 3e-4                                  # learning rate for the Adam optimizer
EPOCHS = 100                               # Trainer max_epochs
BATCH_SIZE = 32                            # DataLoader batch size

# --- backbone ---------------------------------------------------------------
# ResNet-50 built with pretrained=True; this is the network wrapped by the learner.
resnet = models.resnet50(pretrained=True)

# pytorch lightning module

class SelfSupervisedLearner(pl.LightningModule):
    """Lightning module that trains a network through a ``BYOL`` learner.

    Args:
        net: Network handed to ``BYOL`` as its first positional argument.
        lr: Optional learning rate for the Adam optimizer. When ``None``
            (the default) the module-level ``LR`` constant is used, which is
            the behaviour this class always had.
        **kwargs: Forwarded unchanged to ``BYOL``.
    """

    def __init__(self, net, *, lr=None, **kwargs):
        super().__init__()
        self.learner = BYOL(net, **kwargs)
        # Kept as given; the fallback to the global LR is resolved lazily in
        # configure_optimizers so the constant is read at the same time as before.
        self.lr = lr

    def forward(self, images):
        """Pass ``images`` to the BYOL learner and return its output."""
        return self.learner(images)

    def training_step(self, images, _):
        """Return the learner output for the batch under the ``'loss'`` key.

        The second argument (the batch index supplied by Lightning) is unused.
        """
        loss = self.forward(images)
        return {'loss': loss}

    def configure_optimizers(self):
        """Create an Adam optimizer over all parameters of this module."""
        lr = LR if self.lr is None else self.lr
        return torch.optim.Adam(self.parameters(), lr=lr)

    def on_before_zero_grad(self, _):
        """Lightning hook: ask the learner to update its moving average.

        NOTE(review): per the byol-pytorch docs this refreshes the
        exponential-moving-average target network — confirm against the
        installed version.
        """
        self.learner.update_moving_average()

# images dataset

def expand_greyscale(t):
    """Return ``t`` expanded to 3 entries along its first dimension.

    The two trailing dimensions are left at their existing size (``-1``), so a
    ``1 x H x W`` tensor becomes ``3 x H x W`` and a ``3 x H x W`` tensor is
    returned with the same shape.
    """
    target_shape = (3, -1, -1)
    return t.expand(*target_shape)

class ImagesDataset(Dataset):
    """Dataset yielding transformed image tensors found under a folder.

    The folder is searched recursively for regular files whose lower-cased
    suffix is listed in ``IMAGE_EXTS``. Each item is opened with PIL,
    converted to RGB, resized, centre-cropped to ``image_size`` and turned
    into a tensor.

    Args:
        folder: Root directory to search (anything accepted by ``Path``
            after string formatting).
        image_size: Size passed to both ``Resize`` and ``CenterCrop``.
    """

    def __init__(self, folder, image_size):
        super().__init__()
        self.folder = folder
        print(folder)

        # Sorted so that index -> file is stable across runs and machines
        # (glob order depends on the filesystem). The suffix is tested first
        # so is_file() only touches candidates; it filters out directories
        # whose name happens to end in an image suffix.
        self.paths = sorted(
            path
            for path in Path(f'{folder}').glob('**/*')
            if path.suffix.lower() in IMAGE_EXTS and path.is_file()
        )

        print(f'{len(self.paths)} images found')

        self.transform = transforms.Compose([
            transforms.Resize(image_size),
            transforms.CenterCrop(image_size),
            transforms.ToTensor(),
            transforms.Lambda(expand_greyscale)
        ])

    def __len__(self):
        """Number of image files collected at construction time."""
        return len(self.paths)

    def __getitem__(self, index):
        """Load the image at ``index`` and return it after the transform pipeline."""
        path = self.paths[index]
        # The context manager closes the underlying file deterministically;
        # convert() returns a separate image, so it stays usable afterwards.
        with Image.open(path) as img:
            rgb = img.convert('RGB')
        return self.transform(rgb)

In [None]:
# Mount Google Drive at /content/drive; the dataset path used below lives
# under this mount point. `google.colab` is imported in this cell rather than
# with the other imports.
# NOTE(review): presumably only importable on Colab — confirm before running elsewhere.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# Dataset of training images stored on the mounted drive.
ds = ImagesDataset('/content/drive/My Drive/hateful_memes_data/train/data', IMAGE_SIZE)

# shuffle=True so each epoch visits the images in a fresh random order instead
# of replaying the same fixed sequence of batches every epoch.
train_loader = DataLoader(
    ds,
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
    shuffle=True,
)

# Wrap the backbone in the BYOL Lightning module; all keyword arguments are
# forwarded to BYOL.
model = SelfSupervisedLearner(
    resnet,
    image_size = IMAGE_SIZE,
    hidden_layer = 'avgpool',
    projection_size = 256,
    projection_hidden_size = 4096,
    moving_average_decay = 0.99
)

# Train for EPOCHS epochs on NUM_GPUS GPU(s).
trainer = pl.Trainer(gpus=NUM_GPUS, max_epochs=EPOCHS)
trainer.fit(model, train_loader)