# [Chest X-ray pneumonia](https://www.kaggle.com/datasets/paultimothymooney/chest-xray-pneumonia)
---


In [None]:
import time

import torch_xla.core.xla_model as xm

# Time how long it takes to obtain the XLA device and print the elapsed
# wall-clock seconds, rounded to a whole number.
start = time.time()
device = xm.xla_device()
elapsed = time.time() - start
print(f"{elapsed:.0f}")

In [None]:
import torch_xla
# Notebook cell output: the version string of the installed torch_xla.
# Versions noted by the author (presumably the environment choices available
# when this notebook was written -- TODO confirm):
# latest: '1.13'
# pinned: '1.11'
torch_xla.__version__

### Imports

In [None]:
from   collections import OrderedDict
import os
assert 'ISTPUVM' in os.environ, "Select TPU 1VM v3-8 in Settings > Accelerator from the right panel."
from   pathlib import Path
import sys

from   IPython.display import clear_output
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from   torch.utils.data import DataLoader
from   torch.utils.data.distributed import DistributedSampler
import torchvision
from   torchvision import transforms as T, datasets
from   torchvision.models import densenet121
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl
import torch_xla.distributed.xla_multiprocessing as xmp
import torch_xla.test.test_utils as test_utils

# Print the interpreter version, the torch / torchvision / torch_xla versions,
# and the value of the KAGGLE_DOCKER_IMAGE environment variable.
# NOTE: the last line raises KeyError if KAGGLE_DOCKER_IMAGE is not set.
print(sys.version)
print(f"{torch.__version__=}")
print(f"{torchvision.__version__=}")
print(f"{torch_xla.__version__=}")
print(f"{os.environ['KAGGLE_DOCKER_IMAGE']=}")

In [None]:
# Root folder of the dataset; the "train" and "val" splits are read from it.
data_path = Path('../input/chest-xray-pneumonia/chest_xray/chest_xray')

# https://pytorch.org/xla/release/1.13/index.html#torch_xla.distributed.xla_multiprocessing.MpSerialExecutor
# https://discuss.pytorch.org/t/what-does-xmp-mpserialexecutor-really-do/96150
# MX = xmp.MpModelWrapper(model)

# Preprocessing parameters shared by the training and evaluation pipelines.
_IMAGE_SIZE = (224, 224)
_NORM_MEAN = (0.485, 0.456, 0.406)
_NORM_STD = (0.229, 0.224, 0.225)

# Training pipeline: resize, random affine augmentation, tensor conversion,
# then per-channel normalization.
train_transform = T.Compose([
    T.Resize(_IMAGE_SIZE, antialias=True),
    T.RandomAffine(degrees=15, translate=(0.2, 0.2), scale=(0.5, 1.5)),
    T.ToTensor(),
    T.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
])

# Evaluation pipeline: same as training but without the random augmentation.
eval_transform = T.Compose([
    T.Resize(_IMAGE_SIZE, antialias=True),
    T.ToTensor(),
    T.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
])

# Image-folder datasets for the two splits, each with its own transform.
train_data = datasets.ImageFolder(data_path.joinpath("train"), transform=train_transform)
val_data = datasets.ImageFolder(data_path.joinpath("val"), transform=eval_transform)

# DenseNet-121 backbone (constructed without passing any weights) whose
# classifier is replaced by a three-layer head emitting log-probabilities
# over 2 classes.
model = densenet121()

# Layers are created in the same order as they appear in the head so that
# parameter initialisation consumes the RNG in an unchanged sequence.
_classifier_layers = OrderedDict()
_classifier_layers['fcl1'] = nn.Linear(1024, 256)
_classifier_layers['dp1'] = nn.Dropout(0.3)
_classifier_layers['r1'] = nn.ReLU()
_classifier_layers['fcl2'] = nn.Linear(256, 32)
_classifier_layers['dp2'] = nn.Dropout(0.3)
_classifier_layers['r2'] = nn.ReLU()
_classifier_layers['fcl3'] = nn.Linear(32, 2)
_classifier_layers['out'] = nn.LogSoftmax(dim=1)
model.classifier = nn.Sequential(_classifier_layers)

def _build_loader(dataset, batch_size, shuffle):
    """Return ``(sampler, loader)`` sharding *dataset* across the XLA replicas."""
    sampler = DistributedSampler(
        dataset,
        num_replicas=xm.xrt_world_size(),
        rank=xm.get_ordinal(),
        shuffle=shuffle,
    )
    loader = DataLoader(
        dataset,
        batch_size=batch_size,
        sampler=sampler,
        drop_last=False,
        num_workers=8,
    )
    return sampler, loader


def _mp_fn(index):
    """Train and validate the module-level ``model`` on one XLA device.

    Intended as the per-process entry point passed to ``xmp.spawn``. Each
    process builds distributed loaders over the module-level ``train_data``
    and ``val_data``, runs 20 epochs of training followed by validation, and
    the master process prints the sample-weighted mean train loss, validation
    loss and validation accuracy aggregated over all replicas.

    Args:
        index: Process index supplied by the spawner; not used by this
            function (replica identity is read from ``xm.get_ordinal()``).
    """
    global model

    train_sampler, train_loader = _build_loader(train_data, batch_size=128, shuffle=True)
    _, val_loader = _build_loader(val_data, batch_size=8, shuffle=False)

    device = xm.xla_device()
    train_device_loader = pl.MpDeviceLoader(train_loader, device)
    val_device_loader = pl.MpDeviceLoader(val_loader, device)

    # Alternative: wrap the model with xmp.MpModelWrapper and call
    # `.to(device)` on the wrapper instead of on the model itself.
    model.to(device)
    loss_fn = nn.NLLLoss()
    optimizer = optim.Adadelta(model.parameters())

    epochs = 20
    for epoch in range(1, epochs + 1):
        # Without set_epoch() a shuffling DistributedSampler reuses the same
        # permutation every epoch; reseed it so each epoch sees a new order.
        train_sampler.set_epoch(epoch)

        # ---- training pass -------------------------------------------------
        total = 0
        total_samples = 0
        model.train()
        for data, target in train_device_loader:
            optimizer.zero_grad()
            output = model(data)
            loss = loss_fn(output, target)
            loss.backward()
            xm.optimizer_step(optimizer)
            # Weight the batch-mean loss by the batch size so the final
            # average is per sample even when the last batch is smaller.
            n_batch = data.size(0)
            total += loss.item() * n_batch
            total_samples += n_batch
        train_loss_total = xm.mesh_reduce('total', total, np.sum)
        train_loss_total_samples = xm.mesh_reduce('total_samples', total_samples, np.sum)

        # ---- validation pass -----------------------------------------------
        total = 0
        total_samples = 0
        acc = 0
        acc_count = 0
        model.eval()
        with torch.no_grad():
            for data, target in val_device_loader:
                output = model(data)
                loss = loss_fn(output, target)
                n_batch = data.size(0)
                total += loss.item() * n_batch
                total_samples += n_batch
                preds = output.argmax(dim=1)
                # Accumulate a plain Python number, like the other counters,
                # so every value handed to mesh_reduce is a host scalar.
                acc += (preds == target).sum().item()
                acc_count += n_batch
        val_loss_total = xm.mesh_reduce('total', total, np.sum)
        val_loss_total_samples = xm.mesh_reduce('total_samples', total_samples, np.sum)
        val_acc = xm.mesh_reduce('acc', acc, np.sum)
        val_acc_count = xm.mesh_reduce('acc_count', acc_count, np.sum)

        # Combine the per-replica sums into global per-sample metrics.
        train_loss = train_loss_total / train_loss_total_samples
        val_loss = val_loss_total / val_loss_total_samples
        val_accuracy = val_acc / val_acc_count
        xm.master_print(f"[{epoch=:>2}] {train_loss=:.6f} | {val_loss=:.6f} | {val_accuracy=:.2%}", flush=True)

# Launch `_mp_fn` in 8 processes using the 'fork' start method, with no extra
# arguments. NOTE(review): 'fork' presumably lets the workers inherit the
# module-level model and datasets that `_mp_fn` reads -- confirm.
xmp.spawn(_mp_fn, args=(), nprocs=8, start_method='fork')