In [1]:

# ! gdown https://drive.google.com/uc?id=1oij5U35z8qe_cA1LNuDw882dQU1jWmC9
# ! unzip histology-images-query-competition.zip

In [2]:
! gdown https://drive.google.com/uc?id=1CBAM-L2j2qKLJ2jXGAMC4SD9ZMwIgZSo
! gdown https://drive.google.com/uc?id=171edMp03GqaRb6jVyVCM6mUq30pOUPZG

Downloading...
From: https://drive.google.com/uc?id=1CBAM-L2j2qKLJ2jXGAMC4SD9ZMwIgZSo
To: /content/byol.py
100% 7.97k/7.97k [00:00<00:00, 1.05MB/s]
Downloading...
From: https://drive.google.com/uc?id=171edMp03GqaRb6jVyVCM6mUq30pOUPZG
To: /content/validation_ground_truth.csv
100% 5.22k/5.22k [00:00<00:00, 8.78MB/s]


In [3]:
# ! pip install byol-pytorch

In [4]:
import PIL
from tqdm import tqdm
import os
import pandas as pd
import math
import time
import shutil
import random


import torch
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torchvision import models
from torch.utils.data import Dataset,DataLoader


from byol import BYOL


device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [5]:
class ValidDataSet(Dataset):
    def __init__(self, images_folder_path, label_dic, transform=None):
        self.images_folder_path = images_folder_path
        self.label_dic = label_dic
        if transform is not None:
          self.transform=transform
        else:
          self.transform=transforms.Compose([transforms.ToTensor()])

    def __len__(self):
        return len(self.label_dic)
    def __getitem__(self, index):
        element = self.label_dic.iloc[index]
        predic = element['prediction']
        query = element['query']
        img_name=query.split('_')
        images=[]
        # print(img_name)
        for i in range(2):
          images.append(PIL.Image.open(os.path.join(
              self.images_folder_path, img_name[i]+'.png')))
          images[i] = self.transform(images[i])
          
        return images[0],images[1], predic

In [6]:

class TwoCropsTransform:
    """Take two random crops of one image as the query and key."""

    def __init__(self, base_transform):
        self.base_transform = base_transform

    def __call__(self, x):
        q = self.base_transform(x)
        k = self.base_transform(x)
        return [q, k]

class GaussianBlur(object):
    """Gaussian blur augmentation in SimCLR https://arxiv.org/abs/2002.05709"""

    def __init__(self, sigma=[.1, 2.]):
        self.sigma = sigma

    def __call__(self, x):
        sigma = random.uniform(self.sigma[0], self.sigma[1])
        x = x.filter(PIL.ImageFilter.GaussianBlur(radius=sigma))
        return x

augmentation = [
        transforms.RandomResizedCrop(256, scale=(0.2, 1.)),
        transforms.RandomApply([
            transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)  # not strengthened
        ], p=0.8),
        transforms.RandomGrayscale(p=0.2),
        transforms.RandomApply([GaussianBlur([.1, 2.])], p=0.5),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    ]

divide_transform=TwoCropsTransform(transforms.Compose(augmentation))
transform=transforms.Compose(augmentation)

In [7]:
batch_size=32
workers=2

train_dataset = datasets.ImageFolder(
        'train/',transform=divide_transform)

train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True,num_workers=workers, pin_memory=True, sampler=None, drop_last=True)


valid_csv=pd.read_csv('validation_ground_truth.csv')
valid_set=ValidDataSet('train/train',valid_csv,transform=transform)
valid_loader=DataLoader(valid_set, batch_size=1, num_workers=workers, pin_memory=True)

In [8]:


resnet = models.resnet50(pretrained=True)

model = BYOL(
    resnet,
    image_size = 256,
    hidden_layer = 'avgpool'
)

model.to(device)

opt = torch.optim.Adam(model.parameters(), lr=3e-4)



In [9]:
import time

class ProgressMeter(object):
    def __init__(self, num_batches, meters, prefix=""):
        self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
        self.meters = meters
        self.prefix = prefix

    def display(self, batch):
        entries = [self.prefix + self.batch_fmtstr.format(batch)]
        entries += [str(meter) for meter in self.meters]
        print('\t'.join(entries))

    def _get_batch_fmtstr(self, num_batches):
        num_digits = len(str(num_batches // 1))
        fmt = '{:' + str(num_digits) + 'd}'
        return '[' + fmt + '/' + fmt.format(num_batches) + ']'

class AverageMeter(object):
    """Computes and stores the average and current value"""
    def __init__(self, name, fmt=':f'):
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
        return fmtstr.format(**self.__dict__)

def train(train_loader, model, optimizer, epoch,print_freq):
    batch_time = AverageMeter('Time', ':6.3f')
    data_time = AverageMeter('Data', ':6.3f')
    losses = AverageMeter('Loss', ':.4f')
    progress = ProgressMeter(
        len(train_loader),
        [batch_time, data_time, losses],
        prefix="Epoch: [{}]".format(epoch))

    # switch to train mode
    model.train()

    end = time.time()

    for i, (images, _) in enumerate(train_loader):
      data_time.update(time.time() - end)

      # images[0] = images[0].to(device)
      # images[1] = images[1].to(device)

      loss = model([images[0].to(device),images[1].to(device)])
      losses.update(loss.item(), images[0].size(0))

      opt.zero_grad()
      loss.backward()
      opt.step()
      model.update_moving_average() # update moving average of target encoder

      # measure elapsed time
      batch_time.update(time.time() - end)
      end = time.time()

      if i % print_freq == 0:
            progress.display(i)
      torch.save(model.state_dict(), f'checkpoint_{epoch}.pth')
    
def valid(valid_loader,model,threshold):
    # losses = AverageMeter('Loss', ':.4f')
    # progress = ProgressMeter(
    #     len(train_loader),
    #     [losses],
    #     prefix="Epoch: [{}]".format(epoch))
    valid_loss=0
    predict = [[0,0],[0,0]]

    model.eval()
    with torch.no_grad():
      for images1,images2,labels in tqdm(valid_loader):
          # images1 = images1.to(device)
          # images2 = images2.to(device)
          loss = model([images1.to(device),images2.to(device)])
          # losses.update(loss.item(), images1.size(0))

          # print(loss)
          for i in range(len(labels)):
            predict[labels[i]][loss<threshold]+=1
            valid_loss+=loss if labels[i] else 1-loss
      print(f'valid_loss:{valid_loss},predict:',predict)

def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
    torch.save(state, filename)
    if is_best:
        shutil.copyfile(filename, 'model_best.pth.tar')

In [None]:

epoches=10

for epoch in range(epoches):
    train(train_loader,model,opt,epoch,10)
    valid(valid_loader,model,0.5)
    save_checkpoint({
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'optimizer' : opt.state_dict(),
            }, is_best=False, filename='checkpoint_{:04d}.pth.tar'.format(epoch))

# save your improved network
torch.save(resnet.state_dict(), './improved-net.pt')

Epoch: [0][  0/193]	Time  4.405 ( 4.405)	Data  1.897 ( 1.897)	Loss 3.8770 (3.8770)
Epoch: [0][ 10/193]	Time  3.443 ( 3.867)	Data  1.134 ( 1.538)	Loss 1.7444 (2.0892)
Epoch: [0][ 20/193]	Time  9.184 ( 5.227)	Data  6.808 ( 2.880)	Loss 1.7493 (1.9177)
Epoch: [0][ 30/193]	Time  9.434 ( 5.604)	Data  7.068 ( 3.251)	Loss 1.7894 (1.8632)
Epoch: [0][ 40/193]	Time  3.717 ( 5.712)	Data  1.391 ( 3.354)	Loss 1.7268 (1.8276)
Epoch: [0][ 50/193]	Time  6.374 ( 5.867)	Data  4.008 ( 3.507)	Loss 1.5362 (1.7891)
Epoch: [0][ 60/193]	Time  7.411 ( 5.971)	Data  4.981 ( 3.608)	Loss 1.4078 (1.7313)
Epoch: [0][ 70/193]	Time  6.211 ( 6.011)	Data  3.853 ( 3.649)	Loss 1.2377 (1.6660)
Epoch: [0][ 80/193]	Time  6.777 ( 6.059)	Data  4.405 ( 3.695)	Loss 0.7820 (1.5928)
Epoch: [0][ 90/193]	Time  6.630 ( 6.095)	Data  4.255 ( 3.731)	Loss 0.8897 (1.5192)
Epoch: [0][100/193]	Time  6.306 ( 6.126)	Data  3.944 ( 3.761)	Loss 0.9080 (1.4518)
Epoch: [0][110/193]	Time  7.170 ( 6.151)	Data  4.806 ( 3.785)	Loss 0.6664 (1.3912)
Epoc

100%|██████████| 186/186 [00:40<00:00,  4.55it/s]


valid_loss:-200.47352600097656,predict: [[114, 0], [66, 6]]
Epoch: [1][  0/193]	Time  4.205 ( 4.205)	Data  1.716 ( 1.716)	Loss 0.4759 (0.4759)
Epoch: [1][ 10/193]	Time  6.674 ( 4.521)	Data  4.221 ( 2.173)	Loss 0.3498 (0.4484)
Epoch: [1][ 20/193]	Time  6.342 ( 5.399)	Data  3.980 ( 3.037)	Loss 0.6454 (0.4496)
Epoch: [1][ 30/193]	Time  6.432 ( 5.721)	Data  4.078 ( 3.355)	Loss 0.5402 (0.4313)
Epoch: [1][ 40/193]	Time  5.940 ( 5.872)	Data  3.568 ( 3.503)	Loss 0.4410 (0.4373)
Epoch: [1][ 50/193]	Time  7.311 ( 5.988)	Data  4.925 ( 3.616)	Loss 0.3243 (0.4252)
Epoch: [1][ 60/193]	Time  5.923 ( 6.047)	Data  3.557 ( 3.675)	Loss 0.5708 (0.4169)
Epoch: [1][ 70/193]	Time  6.085 ( 6.094)	Data  3.626 ( 3.720)	Loss 0.2387 (0.4101)
Epoch: [1][ 80/193]	Time  5.883 ( 6.127)	Data  3.509 ( 3.753)	Loss 0.3604 (0.4076)
