# Pytorch + TPU + Lightning

> Pytorch🔥 PyTroch - Lightning⚡️ TPU⏱

In [1]:
!curl https://raw.githubusercontent.com/pytorch/xla/master/contrib/scripts/env-setup.py -o pytorch-xla-env-setup.py
!python pytorch-xla-env-setup.py --version nightly --apt-packages libomp5 libopenblas-dev

  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100  5115  100  5115    0     0   9283      0 --:--:-- --:--:-- --:--:--  9283
Updating... This may take around 2 minutes.
Updating TPU runtime to pytorch-nightly ...
Found existing installation: torch 1.5.0
Uninstalling torch-1.5.0:
  Successfully uninstalled torch-1.5.0
Found existing installation: torchvision 0.6.0a0+35d732a
Uninstalling torchvision-0.6.0a0+35d732a:
Done updating TPU runtime
  Successfully uninstalled torchvision-0.6.0a0+35d732a
Copying gs://tpu-pytorch/wheels/torch-nightly-cp37-cp37m-linux_x86_64.whl...
| [1 files][110.1 MiB/110.1 MiB]                                                
Operation completed over 1 objects/110.1 MiB.                                    
Copying gs://tpu-pytorch/wheels/torch_xla-nightly-cp37-cp37m-linux_x86_64.whl...
/ [1 files][127.2 MiB/127.2 MiB]                                    

## Dependencies

In [2]:
!pip install wtfml==0.0.3
!pip install efficientnet_pytorch
!pip install pytorch-lightning

Collecting wtfml==0.0.3
  Downloading wtfml-0.0.3-py3-none-any.whl (10 kB)
Installing collected packages: wtfml
Successfully installed wtfml-0.0.3
You should consider upgrading via the '/opt/conda/bin/python3.7 -m pip install --upgrade pip' command.[0m
Collecting efficientnet_pytorch
  Downloading efficientnet_pytorch-0.6.3.tar.gz (16 kB)
Building wheels for collected packages: efficientnet-pytorch
  Building wheel for efficientnet-pytorch (setup.py) ... [?25ldone
[?25h  Created wheel for efficientnet-pytorch: filename=efficientnet_pytorch-0.6.3-py3-none-any.whl size=12419 sha256=fb89b0d4564fdc47d744bf30acbc94691424eec03f20967b815d85ca803bda51
  Stored in directory: /root/.cache/pip/wheels/90/6b/0c/f0ad36d00310e65390b0d4c9218ae6250ac579c92540c9097a
Successfully built efficientnet-pytorch
Installing collected packages: efficientnet-pytorch
Successfully installed efficientnet-pytorch-0.6.3
You should consider upgrading via the '/opt/conda/bin/python3.7 -m pip install --upgrade pip' co

In [19]:
import gc
import os
import torch
import albumentations

import numpy as np
import pandas as pd

import torch.nn as nn
from sklearn import metrics
from sklearn import model_selection
from torch.nn import functional as F

# from wtfml.engine import Engine
# from wtfml.utils import EarlyStopping
# from wtfml.data_loaders.image import ClassificationDataLoader
from torch.utils.data import Dataset,DataLoader

import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl
import torch_xla.distributed.xla_multiprocessing as xmp

import efficientnet_pytorch

from pytorch_lightning import LightningModule,Trainer
from pytorch_lightning.callbacks import EarlyStopping,Callback

from PIL import Image
from pathlib import Path
from tqdm import trange

In [20]:
MEAN = [0.80619959, 0.62115946, 0.59133584]
STD = [0.15061945, 0.17709774, 0.20317172]

In [21]:
def get_aug(train = False):
    mean = MEAN
    std = STD
    if train:
        train_aug = albumentations.Compose(
        [albumentations.Normalize(mean, std, max_pixel_value=255.0, always_apply=True),
        albumentations.ShiftScaleRotate(shift_limit=0.0625, scale_limit=0.1, rotate_limit=15),
            albumentations.Flip(p=0.5)])
        return train_aug
    else:
        valid_aug = albumentations.Compose(
        [albumentations.Normalize(mean, std, max_pixel_value=255.0,always_apply=True)])
        return valid_aug

In [22]:
aug = get_aug(train=True)

In [23]:
# create folds
df = pd.read_csv("../input/siim-isic-melanoma-classification/train.csv")
df["kfold"] = -1    
df = df.sample(frac=1).reset_index(drop=True)
y = df.target.values
kf = model_selection.StratifiedKFold(n_splits=5)

for f, (t_, v_) in enumerate(kf.split(X=df, y=y)):
    df.loc[v_, 'kfold'] = f

df.to_csv("train_folds.csv", index=False)

In [24]:
TRAIN_DIR = Path("../input/siic-isic-224x224-images/train/")

In [25]:
df

Unnamed: 0,image_name,patient_id,sex,age_approx,anatom_site_general_challenge,diagnosis,benign_malignant,target,kfold
0,ISIC_5234605,IP_1583136,female,65.0,torso,unknown,benign,0,0
1,ISIC_4419280,IP_1362494,female,50.0,head/neck,unknown,benign,0,0
2,ISIC_9042665,IP_4391034,female,55.0,torso,unknown,benign,0,0
3,ISIC_9189319,IP_1800426,female,40.0,upper extremity,unknown,benign,0,0
4,ISIC_7207496,IP_7817798,female,30.0,torso,nevus,benign,0,0
...,...,...,...,...,...,...,...,...,...
33121,ISIC_2008230,IP_7160012,male,60.0,torso,unknown,benign,0,4
33122,ISIC_7658268,IP_7279968,male,45.0,torso,unknown,benign,0,4
33123,ISIC_6749926,IP_0097257,female,65.0,upper extremity,unknown,benign,0,4
33124,ISIC_5148638,IP_9453080,female,60.0,lower extremity,unknown,benign,0,4


In [26]:
def open_function(parent:Path, mode:str = "RGB", resize:int = None,):
    def open_image(image_id:str):
        img = Image.open(parent/f"{image_id}.png").convert(mode)
        if resize:
            img = img.resize((resize,resize))
        return img
    return open_image

In [27]:
df.kfold.value_counts()

0    6626
4    6625
3    6625
2    6625
1    6625
Name: kfold, dtype: int64

In [28]:
open_rgb = open_function(TRAIN_DIR,mode = "RGB")

In [29]:
from torchvision import transforms as trf

In [30]:
class siimData(Dataset):
    def __init__(self,df,path,aug):
        self.df = df
        self.indices = df.index
        self.path = Path(path)
        self.aug = aug
        
    def __len__(self,):return len(self.df)
        
    def __getitem__(self,idx):
        row = dict(self.df.loc[self.indices[idx]])
        image_name = row["image_name"]
        target = row["target"]
        img = open_rgb(image_name)
        
        arr = self.aug(image = np.array(img))["image"]
        return np.moveaxis(arr,[0,1,2],[1,2,0]), target

### Test dataloader

In [1]:
dl = DataLoader(siimData(df,TRAIN_DIR,aug = get_aug(True)),batch_size=4)
x,y = next(iter(dl))
x.shape,y.shape

NameError: name 'DataLoader' is not defined

In [None]:
y

In [17]:
from pytorch_lightning.metrics import Accuracy,F1
from pytorch_lightning.callbacks import 

SyntaxError: invalid syntax (<ipython-input-17-7880f9eff42b>, line 2)

In [None]:
class EfficientNet(LightningModule):
    def __init__(self,tag='efficientnet-b0',batch_size=16):
        super(EfficientNet, self).__init__()
        self.batch_size = batch_size
        self.base_model = efficientnet_pytorch\
            .EfficientNet\
            .from_pretrained(tag)
            
        self.base_model._fc = nn.Linear(
            in_features=1280, 
            out_features=1, 
            bias=True
        )
        self.acc = Accuracy(1)
        self.f1 = F1(1)
        
        self.crit = nn.BCEWithLogitsLoss()
        
    def forward(self, image):
        out = self.base_model(image)
        return out
    
    def training_step(self,batch,batch_idx):
        image, targets = batch
        out = self(image)
        loss = self.crit(out, targets.view(-1, 1).type_as(out))
        log = {"f1":self.f1(out,targets),"acc":self.acc(out,targets)}
        return {"loss":loss,"log":log}
    
    def validation_step(self,batch,batch_idx):
        image, targets = batch
        out = self(image)
        loss = self.crit(out, targets.view(-1, 1).type_as(out))
        log = {"f1":self.f1(out,targets),"acc":self.acc(out,targets)}
        return {"loss":loss,"log":log}
    
    def get_sampler(self,ds,shuffle):
        # required for TPU support
        sampler = torch.utils.data.distributed.DistributedSampler(
                    ds,
                    num_replicas=xm.xrt_world_size(),
                    rank=xm.get_ordinal(),
                    shuffle=True
                )
        return sampler
    
    def train_dataloader(self):
        aug = get_aug(True)
        ds = siimData(self.train_df,TRAIN_DIR,aug = aug)
        sampler = self.get_sampler(ds,shuffle=True)
        dl = DataLoader(ds,sampler = sampler,batch_size = self.batch_size,num_workers = 2)
        return dl
    
    def val_dataloader(self):
        aug = get_aug(False)
        ds = siimData(self.valid_df,TRAIN_DIR,aug = aug)
        sampler = self.get_sampler(ds,shuffle=False)
        dl = DataLoader(ds,sampler = sampler,batch_size = self.batch_size,num_workers = 2)
        return dl
    
    def configure_optimizers(self):
        return torch.optim.AdamW(self.parameters(),1e-4)
    
    @classmethod
    def from_fold(cls,tag,df,kfold,batch_size = 16):
        obj = cls(tag,batch_size=batch_size)
        obj.train_df = df[df["kfold"]!=kfold].sample(frac=1.).reset_index(drop=True)
        obj.valid_df = df[df["kfold"]==kfold].sample(frac=1.).reset_index(drop=True)
        return obj

In [None]:
net = EfficientNet.from_fold('efficientnet-b0',df = df,kfold = 0,batch_size = 64)

## Trainer

In [None]:
trainer = Trainer(max_epochs=20,tpu_cores=8,distributed_backend="ddp",replace_sampler_ddp = False)

In [None]:
trainer.fit(net)

In [None]:
# init model here
# MX = EfficientNet()

In [None]:
def train():
    training_data_path = "../input/siic-isic-224x224-images/train/"
    df = pd.read_csv("/kaggle/working/train_folds.csv")
    device = xm.xla_device()
    epochs = 5
    train_bs = 32
    valid_bs = 16
    fold = 0

    df_train = df[df.kfold != fold].reset_index(drop=True)
    df_valid = df[df.kfold == fold].reset_index(drop=True)

    model = MX.to(device)

    mean = (0.485, 0.456, 0.406)
    std = (0.229, 0.224, 0.225)
    train_aug = albumentations.Compose(
        [
            albumentations.Normalize(
                mean, 
                std, 
                max_pixel_value=255.0, 
                always_apply=True
            ),
            albumentations.ShiftScaleRotate(
                shift_limit=0.0625, 
                scale_limit=0.1, 
                rotate_limit=15
            ),
            albumentations.Flip(p=0.5)
        ]
    )

    valid_aug = albumentations.Compose(
        [
            albumentations.Normalize(
                mean, 
                std, 
                max_pixel_value=255.0,
                always_apply=True
            )
        ]
    )

    train_images = df_train.image_name.values.tolist()
    train_images = [
        os.path.join(training_data_path, i + ".png") for i in train_images
    ]
    train_targets = df_train.target.values

    valid_images = df_valid.image_name.values.tolist()
    valid_images = [
        os.path.join(training_data_path, i + ".png") for i in valid_images
    ]
    valid_targets = df_valid.target.values

    train_loader = ClassificationDataLoader(
        image_paths=train_images,
        targets=train_targets,
        resize=None,
        augmentations=train_aug,
    ).fetch(
        batch_size=train_bs, 
        drop_last=True, 
        num_workers=4, 
        shuffle=True, 
        tpu=True
    )

    valid_loader = ClassificationDataLoader(
        image_paths=valid_images,
        targets=valid_targets,
        resize=None,
        augmentations=valid_aug,
    ).fetch(
        batch_size=valid_bs, 
        drop_last=False, 
        num_workers=2, 
        shuffle=False, 
        tpu=True
    )

    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        patience=3,
        threshold=0.001,
        mode="min"
    )

    es = EarlyStopping(patience=5, mode="min", tpu=True)
    eng = Engine(model, optimizer, device=device, use_tpu=True, tpu_print=25)

    for epoch in range(epochs):
        train_loss = eng.train(train_loader)
        valid_loss = eng.evaluate(valid_loader)
        xm.master_print(f"Epoch = {epoch}, LOSS = {valid_loss}")
        scheduler.step(valid_loss)

        es(valid_loss, model, model_path=f"model_fold_{fold}.bin")
        if es.early_stop:
            xm.master_print("Early stopping")
            break
        gc.collect()

In [None]:
def _mp_fn(rank, flags):
    torch.set_default_tensor_type('torch.FloatTensor')
    a = train()

In [None]:
FLAGS={}
xmp.spawn(_mp_fn, args=(FLAGS,), nprocs=8, start_method='fork')

## Recalculate mean and std

In [None]:
# def img2arr(path):
#     path = Path(path)
#     def open_img(img_id):
#         return np.array(Image.open(path/f"{img_id}.png"))/255
#     return open_img

# means = []
# stds = []
# images = []
# image_names = list(df.image_name)
# open_train = img2arr(TRAIN_DIR)
# with torch.no_grad():
#     for i in trange(len(image_names)):
#         img = image_names[i]
#         images.append(open_train(img))
#         if len(images)==320:
#             concatenated = np.concatenate(images,axis=0).reshape(-1,3)
#             means.append(concatenated.mean(0))
#             stds.append(concatenated.std(0))
#             images = []

# np.stack(means,axis=0).mean(0),np.stack(stds,axis=0).mean(0)