<a href="https://colab.research.google.com/github/jmarrietar/ocular/blob/master/notebooks/Refactor_PyTorch_XLA_ResNet18_DR_Training.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## PyTorch/XLA ResNet/DR (GPU or TPU)

In [None]:
import gdown

In [None]:
from google.colab import auth, drive
from googleapiclient.discovery import build
from googleapiclient.http import MediaFileUpload

In [None]:
def download(data, url):
    """Download a zipped dataset with gdown and extract it into the working directory.

    Args:
        data: Dataset name; the archive is written to ``"<data>.zip"`` in the
            current working directory.
        url: Download URL handed to ``gdown.download``.
    """
    import zipfile

    archive_path = "{}.zip".format(data)

    # Download dataset
    gdown.download(url, archive_path, quiet=False)

    # Uncompress dataset; the context manager closes the archive even when
    # extraction raises (e.g. corrupt or partial download).
    with zipfile.ZipFile(archive_path, "r") as archive:
        archive.extractall()

In [None]:
# Maps a dataset name to its Google Drive download URL. The name doubles as the
# archive stem: download(name, url) saves the file as "<name>.zip".
data_samples = {
    "sample@200": "https://drive.google.com/uc?id=1FfV7YyDJvNUCDP5r3-8iQfZ2-xJp_pgb",
    "sample@500": "https://drive.google.com/uc?id=1dHwUqpmSogEdjAB9rwDUL-OKFRUcVXte",
    "sample@1000": "https://drive.google.com/uc?id=1DPZrHrj3Bdte5Dc6NCZ33CAqMG-Oipa2",
    "sample@2000": "https://drive.google.com/uc?id=1PB7uGd-dUnZKnKZpZl-HvE1DVcWgX50F",
    "sample@3000": "https://drive.google.com/uc?id=1_yre5K9YYvJgSrT4xvrI8eD_htucIywA",
    "sample@4000_images": "https://drive.google.com/uc?id=1dqVB8EozEpwWzyuU80AauoQmsiw3Gtm2",
    "sample@20000": "https://drive.google.com/uc?id=1MTDpLzpmhSiZq2jSdmHx2UDPn9FC8gzO",
    "val-voets-tf": "https://drive.google.com/uc?id=1VzVgMGTkBBPG2qbzLunD9HvLzH6tcyrv",
    "train_voets": "https://drive.google.com/uc?id=1AmcFh1MOOZ6aqKm2eO7XEdgmIEqHKTZ5",
    "voets_test_images": "https://drive.google.com/uc?id=15S_V3B_Z3BOjCT3AbO2c887FyS5B0Lyd"
}

In [None]:
# Key into data_samples selecting the dataset to download; the same string is
# later used as the ImageFolder root directory inside train_resnet().
UNLABELED = 'train_voets'

In [None]:
# Look up the URL for the selected dataset, then download and unzip it into
# the current working directory.
URL_UNLABELED = data_samples[UNLABELED]
download(UNLABELED, URL_UNLABELED)

Downloading...
From: https://drive.google.com/uc?id=1AmcFh1MOOZ6aqKm2eO7XEdgmIEqHKTZ5
To: /content/train_voets.zip
3.09GB [00:32, 94.1MB/s]


In [None]:
# Mount Google Drive at /content/drive. train_resnet() writes its checkpoints
# under the relative path "drive/MyDrive/...", which resolves into this mount
# when the working directory is /content.
drive.mount('/content/drive')
# Authenticate the Colab user.
# NOTE(review): presumably needed for the googleapiclient imports above — confirm.
auth.authenticate_user()

Mounted at /content/drive


In [None]:
!pip install cloud-tpu-client==0.10 https://storage.googleapis.com/tpu-pytorch/wheels/torch_xla-1.8.1-cp37-cp37m-linux_x86_64.whl

Collecting cloud-tpu-client==0.10
  Downloading https://files.pythonhosted.org/packages/56/9f/7b1958c2886db06feb5de5b2c191096f9e619914b6c31fdf93999fdbbd8b/cloud_tpu_client-0.10-py3-none-any.whl
Collecting torch-xla==1.8.1
[?25l  Downloading https://storage.googleapis.com/tpu-pytorch/wheels/torch_xla-1.8.1-cp37-cp37m-linux_x86_64.whl (145.0MB)
[K     |████████████████████████████████| 145.0MB 45kB/s 
Collecting google-api-python-client==1.8.0
[?25l  Downloading https://files.pythonhosted.org/packages/9a/b4/a955f393b838bc47cbb6ae4643b9d0f90333d3b4db4dc1e819f36aad18cc/google_api_python_client-1.8.0-py3-none-any.whl (57kB)
[K     |████████████████████████████████| 61kB 3.0MB/s 
[31mERROR: earthengine-api 0.1.264 has requirement google-api-python-client<2,>=1.12.1, but you'll have google-api-python-client 1.8.0 which is incompatible.[0m
Installing collected packages: google-api-python-client, cloud-tpu-client, torch-xla
  Found existing installation: google-api-python-client 1.12.8
  

Only run the commented-out setup cell further below (the one that sets `VERSION = "nightly"`) if you would like a nightly release.

In [None]:
import numpy as np
import os
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met
import torch_xla.distributed.parallel_loader as pl
import torch_xla.distributed.xla_multiprocessing as xmp
import torch_xla.utils.utils as xu
import torchvision
from torchvision import datasets, transforms
import logging
from torch.utils.tensorboard import SummaryWriter


In [None]:
# VERSION = "nightly"  #@param ["nightly", "20200516"]  # or YYYYMMDD format
# !curl https://raw.githubusercontent.com/pytorch/xla/master/contrib/scripts/env-setup.py -o pytorch-xla-env-setup.py
# !python pytorch-xla-env-setup.py --version $VERSION

In [None]:
# PyTorch/XLA GPU Setup (only if GPU runtime)
import os

# Colab exports COLAB_GPU=1 on GPU runtimes; only then expose a single GPU to
# XLA and point it at the CUDA installation.
if os.environ.get('COLAB_GPU', '0') == '1':
  os.environ.update({
      'GPU_NUM_DEVICES': '1',
      'XLA_FLAGS': '--xla_gpu_cuda_data_dir=/usr/local/cuda/',
  })

### Define Parameters



In [None]:
def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
    """Serialize ``state`` with ``torch.save`` and optionally keep a "best" copy.

    Args:
        state: Object to serialize (typically a dict of model/optimizer state).
        is_best: When true, the written file is also copied to
            ``'model_best.pth.tar'`` in the current working directory.
        filename: Destination path for the checkpoint.
    """
    # Imported locally: the notebook's import cell does not bring in shutil,
    # so the is_best branch previously raised NameError.
    import shutil

    torch.save(state, filename)
    if is_best:
        shutil.copyfile(filename, 'model_best.pth.tar')

In [None]:
def accuracy_func(output, target, topk=(1,)):
    """Return top-k accuracies (in percent) for every k listed in ``topk``.

    ``output`` holds per-class scores with one row per sample and ``target``
    the true class index of each row. The result is a list with one
    single-element tensor per requested k, in the same order as ``topk``.
    """
    with torch.no_grad():
        largest_k = max(topk)
        n_samples = target.size(0)

        # Indices of the largest_k best-scoring classes, transposed so that
        # row i holds every sample's i-th ranked guess.
        ranked = output.topk(largest_k, dim=1, largest=True, sorted=True)[1].t()
        hits = ranked.eq(target.view(1, -1).expand_as(ranked))

        return [
            hits[:k].reshape(-1).float().sum(0, keepdim=True).mul_(100.0 / n_samples)
            for k in topk
        ]

In [None]:
# Define Parameters
# Training configuration shared with every spawned process (handed to
# xmp.spawn through args=(FLAGS,)).
FLAGS = {
    'data_dir': "/tmp/cifar",
    'batch_size': 64,
    'num_workers': 2,
    'learning_rate': 0.00001,
    'momentum': 0.9,
    'num_epochs': 100,
    # Eight cores when a TPU is attached (TPU_NAME set), otherwise one.
    'num_cores': 8 if os.environ.get('TPU_NAME', None) else 1,
    'log_steps': 100,
    'metrics_debug': False,
}

In [None]:
import numpy as np

np.random.seed(0)

import numpy as np
import torch
from torch import nn
from torchvision.transforms import transforms
from torchvision import transforms, datasets

np.random.seed(0)

class GaussianBlur(object):
    """Blur a single 3-channel image on CPU with a separable Gaussian kernel.

    The blur is implemented as two depthwise convolutions (one per spatial
    axis) preceded by reflection padding, so the output has the same spatial
    size as the input. A fresh sigma is drawn on every call.

    Args:
        kernel_size: Requested kernel width; it is converted to the odd
            size ``2 * (kernel_size // 2) + 1``.
    """
    def __init__(self, kernel_size):
        radius = kernel_size // 2
        kernel_size = radius * 2 + 1
        # groups=3 makes each convolution depthwise: every channel is filtered
        # independently with its own (identical) 1-D kernel.
        self.blur_h = nn.Conv2d(3, 3, kernel_size=(kernel_size, 1),
                                stride=1, padding=0, bias=False, groups=3)
        self.blur_v = nn.Conv2d(3, 3, kernel_size=(1, kernel_size),
                                stride=1, padding=0, bias=False, groups=3)
        self.k = kernel_size
        self.r = radius

        self.blur = nn.Sequential(
            nn.ReflectionPad2d(radius),
            self.blur_h,
            self.blur_v
        )

        self.pil_to_tensor = transforms.ToTensor()
        self.tensor_to_pil = transforms.ToPILImage()

    def __call__(self, img):
        """Return a blurred PIL image for the given PIL image."""
        img = self.pil_to_tensor(img).unsqueeze(0)

        # Sample the blur strength, then build a normalized 1-D Gaussian.
        sigma = np.random.uniform(0.1, 2.0)
        x = np.arange(-self.r, self.r + 1)
        x = np.exp(-np.power(x, 2) / (2 * sigma * sigma))
        x = x / x.sum()
        x = torch.from_numpy(x).view(1, -1).repeat(3, 1)

        # Load the same 1-D kernel into both convolutions (one per axis).
        self.blur_h.weight.data.copy_(x.view(3, 1, self.k, 1))
        self.blur_v.weight.data.copy_(x.view(3, 1, 1, self.k))

        with torch.no_grad():
            img = self.blur(img)
            # Drop only the batch axis added above; a bare squeeze() would
            # also remove a height or width axis that happens to be 1.
            img = img.squeeze(0)

        img = self.tensor_to_pil(img)

        return img

class ContrastiveLearningViewGenerator(object):
    """Produce ``n_views`` independently transformed copies of one input.

    Calling the instance applies ``base_transform`` to the same input
    ``n_views`` times and returns the results as a list.
    """

    def __init__(self, base_transform, n_views=2):
        self.n_views = n_views
        self.base_transform = base_transform

    def __call__(self, x):
        views = []
        for _ in range(self.n_views):
            views.append(self.base_transform(x))
        return views

In [None]:
def get_simclr_pipeline_transform(size, s=1):
    """Build the SimCLR-style augmentation pipeline.

    Args:
        size: Output crop size; also sets the blur kernel to ``int(0.1 * size)``.
        s: Scale factor applied to every colour-jitter strength.

    Returns:
        A ``transforms.Compose`` ending in ``ToTensor``.
    """
    strength = 0.8 * s
    color_jitter = transforms.ColorJitter(strength, strength, strength, 0.2 * s)

    augmentations = [
        transforms.RandomResizedCrop(size=size),
        transforms.RandomHorizontalFlip(),
        transforms.RandomApply([color_jitter], p=0.8),
        transforms.RandomGrayscale(p=0.2),
        GaussianBlur(kernel_size=int(0.1 * size)),
        transforms.ToTensor(),
    ]
    return transforms.Compose(augmentations)

In [None]:
def info_nce_loss(features, device, n_views=2, temperature=0.07):
    """Build InfoNCE logits and targets from a stacked batch of view embeddings.

    ``features`` must hold the embeddings of ``n_views`` augmented views
    concatenated along dim 0 (view 0 of every sample, then view 1, ...), so
    rows ``i`` and ``i + batch_size`` belong to the same source sample.

    Args:
        features: 2-D tensor with ``n_views * batch_size`` rows.
        device: Device on which the label and mask tensors are placed.
        n_views: Number of views per sample (default 2, the previous
            hard-coded value).
        temperature: Softmax temperature dividing the logits (default 0.07,
            the previous hard-coded value).

    Returns:
        ``(logits, labels)``: for each row, the similarities to its positives
        come first, followed by the negatives; ``labels`` is all zeros
        (dtype long), i.e. the target class for cross-entropy is column 0.

    Raises:
        ValueError: If the number of rows is not a multiple of ``n_views``.
    """
    total_rows = features.shape[0]
    if total_rows % n_views != 0:
        raise ValueError(
            "features has {} rows, which is not a multiple of n_views={}".format(
                total_rows, n_views))
    # Derive the per-view batch size from the data instead of the global
    # FLAGS dict, so the function also works for other batch sizes.
    batch_size = total_rows // n_views

    labels = torch.cat([torch.arange(batch_size) for _ in range(n_views)], dim=0)
    # labels[i, j] == 1 exactly when rows i and j come from the same sample.
    labels = (labels.unsqueeze(0) == labels.unsqueeze(1)).float()
    labels = labels.to(device)

    features = F.normalize(features, dim=1)

    # Cosine similarity between every pair of rows.
    similarity_matrix = torch.matmul(features, features.T)

    # discard the main diagonal from both: labels and similarities matrix
    mask = torch.eye(labels.shape[0], dtype=torch.bool).to(device)
    labels = labels[~mask].view(labels.shape[0], -1)
    similarity_matrix = similarity_matrix[~mask].view(similarity_matrix.shape[0], -1)

    # select and combine multiple positives
    positives = similarity_matrix[labels.bool()].view(labels.shape[0], -1)

    # select only the negatives
    negatives = similarity_matrix[~labels.bool()].view(similarity_matrix.shape[0], -1)

    logits = torch.cat([positives, negatives], dim=1)
    labels = torch.zeros(logits.shape[0], dtype=torch.long).to(device)

    logits = logits / temperature
    return logits, labels

In [None]:
import torch.nn as nn
import torchvision.models as models

class BaseSimCLRException(Exception):
    """Base class for the SimCLR-specific exceptions defined below."""


class InvalidBackboneError(BaseSimCLRException):
    """Raised when the requested backbone name is not a supported ConvNet."""


class InvalidDatasetSelection(BaseSimCLRException):
    """Raised when the choice of dataset is invalid."""


class ResNetSimCLR(nn.Module):
    """ResNet backbone whose final layer is replaced by an MLP projection head.

    Args:
        base_model: ``"resnet18"`` (randomly initialised) or ``"resnet50"``
            (ImageNet-pretrained weights).
        out_dim: Dimensionality of the projection output.

    Raises:
        InvalidBackboneError: If ``base_model`` is not a supported name.
    """

    def __init__(self, base_model, out_dim=128):
        super(ResNetSimCLR, self).__init__()
        self.out_dim = out_dim

        self.backbone = self._get_basemodel(base_model)
        dim_mlp = self.backbone.fc.in_features

        # add mlp projection head
        self.backbone.fc = nn.Sequential(nn.Linear(dim_mlp, dim_mlp), nn.ReLU(), self.backbone.fc)

    def _get_basemodel(self, model_name):
        """Construct only the requested backbone, with an ``out_dim``-wide fc layer."""
        # Build lazily: the previous eager dict instantiated both networks and
        # downloaded the resnet50 weights even when resnet18 was selected.
        if model_name == "resnet18":
            return models.resnet18(pretrained=False, num_classes=self.out_dim)
        if model_name == "resnet50":
            model = models.resnet50(pretrained=True)
            # The pretrained checkpoint ships a 1000-way classifier; swap it
            # for a fresh layer so that out_dim is honoured. Checkpoints saved
            # with the old 1000-wide head will no longer load into this layer.
            model.fc = nn.Linear(model.fc.in_features, self.out_dim)
            return model
        raise InvalidBackboneError(
            "Invalid backbone architecture. Check the config file and pass one of: resnet18 or resnet50")

    def forward(self, x):
        """Run the backbone, including the projection head, on ``x``."""
        return self.backbone(x)

In [None]:
# Created once in the parent process, before xmp.spawn forks the workers.
# SERIAL_EXEC is used in train_resnet() to run dataset construction through
# SERIAL_EXEC.run(); WRAPPED_MODEL holds the model that each worker moves to
# its own device with WRAPPED_MODEL.to(device).
SERIAL_EXEC = xmp.MpSerialExecutor()

#WRAPPED_MODEL = xmp.MpModelWrapper(ResNetSimCLR(base_model='resnet50', out_dim=128))
WRAPPED_MODEL = xmp.MpModelWrapper(ResNetSimCLR(base_model='resnet50'))

Downloading: "https://download.pytorch.org/models/resnet50-19c8e357.pth" to /root/.cache/torch/hub/checkpoints/resnet50-19c8e357.pth


HBox(children=(FloatProgress(value=0.0, max=102502400.0), HTML(value='')))




In [None]:
def train_resnet():
  """Run SimCLR pre-training on the unlabeled image folder for this process.

  Reads its configuration from the global ``FLAGS`` dict, builds a
  distributed data loader over ``UNLABELED``, trains ``WRAPPED_MODEL`` with
  the InfoNCE objective and saves the model's state dict after every epoch.

  Returns:
    ``(accuracy, data, pred, target)``. These are placeholders that are never
    updated by the loop: ``0.0`` and three ``None`` values.
  """
  torch.manual_seed(1)

  def get_dataset():
    # Every sample is turned into a list of two augmented 224x224 views.
    train_dataset = datasets.ImageFolder(root="{}".format(UNLABELED), 
                                         transform=ContrastiveLearningViewGenerator(
                                        get_simclr_pipeline_transform(224),n_views=2))

    return train_dataset
  
  # Build the dataset through the serial executor rather than in every
  # process at the same time.
  train_dataset = SERIAL_EXEC.run(get_dataset)
  
  train_sampler = torch.utils.data.distributed.DistributedSampler(
      train_dataset,
      num_replicas=xm.xrt_world_size(),
      rank=xm.get_ordinal(),
      shuffle=True)
  
  train_loader = torch.utils.data.DataLoader(
      train_dataset,
      batch_size=FLAGS['batch_size'],
      sampler=train_sampler,
      num_workers=FLAGS['num_workers'],
      drop_last=True)
  

  # Scale learning rate to num cores
  learning_rate = FLAGS['learning_rate'] * xm.xrt_world_size()

  # Get loss function, optimizer, and model
  device = xm.xla_device()
  model = WRAPPED_MODEL.to(device)

  optimizer = torch.optim.Adam(model.parameters(), 
                               learning_rate, 
                               weight_decay=5e-4)

  # Cross-entropy over the InfoNCE logits (the positive sits in column 0).
  criterion = torch.nn.CrossEntropyLoss().to(device)

  def train_loop_fn(loader):
    tracker = xm.RateTracker()
    model.train()

    for step, (data, _) in enumerate(loader):
      optimizer.zero_grad()

      # `data` is the list of views; stack them along the batch dimension.
      data = torch.cat(data, dim=0)

      output = model(data)
      logits, labels = info_nce_loss(output, device)

      loss = criterion(logits, labels)

      loss.backward()
      xm.optimizer_step(optimizer)
      tracker.add(FLAGS['batch_size'])

      top1, top5 = accuracy_func(logits, labels, topk=(1, 5))

      if step % FLAGS['log_steps'] == 0:
        print('[xla:{}]({}) Loss={:.5f} Rate={:.2f} GlobalRate={:.2f} Time={}'.format(
            xm.get_ordinal(), step, loss.item(), tracker.rate(),
            tracker.global_rate(), time.asctime()), flush=True)
        print(f"Top1 accuracy: {top1[0]}")


  # Train and eval loops
  accuracy = 0.0
  data, pred, target = None, None, None
  for epoch in range(1, FLAGS['num_epochs'] + 1):
    # DistributedSampler seeds its shuffle with the epoch number; without
    # this call every epoch would iterate the data in the same order.
    train_sampler.set_epoch(epoch)

    para_loader = pl.ParallelLoader(train_loader, [device])
    train_loop_fn(para_loader.per_device_loader(device))
    xm.master_print("Finished training epoch {}".format(epoch))

    # Overwrite the same checkpoint file at the end of every epoch.
    xm.save(
            model.state_dict(),
            "drive/MyDrive/Colab Notebooks/SimCLR/models/SimCLR-1-DR-pytorch/net-DR-SimCLR.pt"
        )

    if FLAGS['metrics_debug']:
      xm.master_print(met.metrics_report(), flush=True)


  return accuracy, data, pred, target

In [None]:
# Start training processes
def _mp_fn(rank, flags):
  """Per-process entry point handed to ``xmp.spawn``.

  Args:
    rank: Process index supplied by ``xmp.spawn``; not used here.
    flags: Configuration dict; rebound to the module-level ``FLAGS`` so that
      ``train_resnet`` and the helpers it calls read the same settings.
  """
  global FLAGS
  FLAGS = flags
  torch.set_default_tensor_type('torch.FloatTensor')
  # The returned values are placeholders and are not used further.
  accuracy, data, pred, target = train_resnet()


# Launch FLAGS['num_cores'] training processes with the 'fork' start method;
# each one runs _mp_fn with the FLAGS dict as its argument.
xmp.spawn(_mp_fn, args=(FLAGS,), nprocs=FLAGS['num_cores'],
          start_method='fork')

[xla:0](0) Loss=4.53163 Rate=2.80 GlobalRate=2.80 Time=Tue May 25 15:12:22 2021
[xla:6](0) Loss=4.54450 Rate=2.66 GlobalRate=2.66 Time=Tue May 25 15:12:26 2021
Top1 accuracy: 9.375
Top1 accuracy: 9.375
[xla:3](0) Loss=4.64167 Rate=2.49 GlobalRate=2.49 Time=Tue May 25 15:12:29 2021
Top1 accuracy: 5.46875
[xla:2](0) Loss=4.63687 Rate=2.45 GlobalRate=2.45 Time=Tue May 25 15:12:32 2021
Top1 accuracy: 9.375
[xla:7](0) Loss=4.65762 Rate=2.75 GlobalRate=2.75 Time=Tue May 25 15:12:34 2021
Top1 accuracy: 4.6875
[xla:4](0) Loss=4.65069 Rate=4.42 GlobalRate=4.42 Time=Tue May 25 15:12:50 2021
Top1 accuracy: 7.03125
[xla:5](0) Loss=4.61165 Rate=3.64 GlobalRate=3.64 Time=Tue May 25 15:12:55 2021
Top1 accuracy: 1.5625
[xla:1](0) Loss=4.64581 Rate=3.58 GlobalRate=3.58 Time=Tue May 25 15:12:57 2021
Top1 accuracy: 7.03125
[xla:6](100) Loss=1.32761 Rate=7.62 GlobalRate=10.60 Time=Tue May 25 15:22:12 2021
[xla:3](100) Loss=1.42676 Rate=7.59 GlobalRate=10.63 Time=Tue May 25 15:22:12 2021
[xla:5](100) Loss=