In [None]:
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms

In [None]:
!mgenv/bin/pip install matplotlib tensorboard

In [None]:
# Every later cell moves the model and data with .cuda(), so stop here with a clear
# message on a CPU-only kernel rather than failing partway through model setup.
assert torch.cuda.is_available(), "CUDA is not available, but this notebook calls .cuda() unconditionally."
torch.cuda.is_available()

In [None]:
import matplotlib.pyplot as plt
%config InlineBackend.figure_format = 'retina'

def show(img, mean=None, std=None):
    """Display a single channels-first (C, H, W) image tensor with matplotlib.

    Args:
        img: image tensor as produced by ``transforms.ToTensor()``; it may live on
            the GPU or require grad, because it is detached and copied to the CPU first.
        mean, std: optional per-channel statistics of a preceding
            ``transforms.Normalize``. When both are given, the normalization is undone
            and the result clamped to [0, 1] so colours display correctly. By default
            the tensor is shown unchanged.
    """
    # imshow needs host memory and no autograd history.
    pixels = img.detach().cpu()
    if mean is not None and std is not None:
        mean_t = torch.as_tensor(mean, dtype=pixels.dtype).view(-1, 1, 1)
        std_t = torch.as_tensor(std, dtype=pixels.dtype).view(-1, 1, 1)
        pixels = (pixels * std_t + mean_t).clamp(0, 1)
    # matplotlib expects channels-last (H, W, C).
    plt.imshow(pixels.permute(1, 2, 0), interpolation='nearest')

In [None]:
# Preprocessing shared by both splits: crop, downscale, convert to a tensor, then
# normalise with the ImageNet channel statistics.
# NOTE(review): per the torchvision docs, CenterCrop zero-pads any image smaller than
# 2048 px on a side, so the black border width depends on the source resolution.
# Confirm this is intended rather than resizing first and cropping afterwards.
TRAIN_DIR = "/data/kaggle/paultimothymooney/chest-xray-pneumonia/train"
VAL_DIR = "/data/kaggle/paultimothymooney/chest-xray-pneumonia/val"
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

transform_stack = transforms.Compose(
    [
        transforms.CenterCrop(2048),
        transforms.Resize(256),
        transforms.ToTensor(),
        transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ]
)

# Built in order: train first, then val.
train, val = (
    torchvision.datasets.ImageFolder(split_dir, transform=transform_stack)
    for split_dir in (TRAIN_DIR, VAL_DIR)
)

In [None]:
# Number of samples per split, displayed as (train, val).
tuple(len(split) for split in (train, val))

In [None]:
# Targets are the integer indices from class_to_idx, so both splits must map folder
# names to the same indices; otherwise validation loss compares against wrong labels.
assert train.class_to_idx == val.class_to_idx, (train.class_to_idx, val.class_to_idx)
train.class_to_idx, val.class_to_idx

In [None]:
# Preview the first training sample; dataset items are (image, label) pairs.
sample_image, sample_label = train[0]
show(sample_image)

In [None]:
# ResNet-18 backbone followed by a linear head producing 2 logits; the whole stack is
# moved to the GPU.
# NOTE(review): resnet18() is called without a weights argument, so no pretrained
# weights are requested, yet inputs use ImageNet normalisation. Confirm this is intended.
backbone = torchvision.models.resnet18()
classifier_head = nn.Linear(backbone.fc.out_features, 2)
model = nn.Sequential(backbone, classifier_head).cuda()

In [None]:
from torch.utils.tensorboard import SummaryWriter
# TensorBoard output directory for this run; re-running the notebook reuses it.
TB_LOG_DIR = "./logs/001"
tb_writer = SummaryWriter(log_dir=TB_LOG_DIR)

In [None]:
# Shuffled mini-batches for the training loop.
train_loader = torch.utils.data.DataLoader(train, batch_size=4, shuffle=True, pin_memory=True)

# One fixed batch of held-out images, moved to the GPU once and reused for every
# validation step. It must come from `val`: drawing it from `train` (unshuffled)
# would score the model on images it is being fitted to, likely all from one class.
val_x, val_y = next(iter(torch.utils.data.DataLoader(val, batch_size=16, shuffle=False)))
val_x = val_x.cuda()
val_y = val_y.cuda()

val_x.size()

In [None]:
# add_images expects float pixels in [0, 1], but val_x has been through the ImageNet
# Normalize in transform_stack. Undo it (for display only) so TensorBoard does not clip.
imagenet_mean = val_x.new_tensor([0.485, 0.456, 0.406]).view(1, -1, 1, 1)
imagenet_std = val_x.new_tensor([0.229, 0.224, 0.225]).view(1, -1, 1, 1)
tb_writer.add_images("val", (val_x * imagenet_std + imagenet_mean).clamp(0, 1))
tb_writer.flush()

In [None]:
# `loss` is the criterion module, called as loss(logits, targets); Adam updates every
# model parameter.
loss = nn.CrossEntropyLoss()
LEARNING_RATE = 1e-5
optimizer = torch.optim.Adam(params=model.parameters(), lr=LEARNING_RATE)

In [None]:
from tqdm import tqdm

In [None]:
def validate():
    """Return the criterion `loss` of `model` on the fixed batch (val_x, val_y).

    The model is put in eval mode for the forward pass, so normalisation layers such
    as BatchNorm use their running statistics and do not update them from validation
    data (torch.no_grad alone does not stop those updates). The previous train/eval
    mode is restored afterwards, even if the forward pass raises.

    Returns:
        The scalar loss tensor (callers convert it with .item()).
    """
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            y_hat = model(val_x)
            val_loss = loss(y_hat, val_y)
    finally:
        model.train(was_training)
    return val_loss

# One epoch over train_loader. Before every 10th optimizer step the fixed validation
# batch is scored. Train and val losses are logged against the number of optimizer
# steps taken so far, so both curves share one x-axis.
model.train()
num_steps = 0
with tqdm(train_loader) as pbar:
    for i, batch in enumerate(pbar):

        # Run validation before every 10th batch
        if i % 10 == 0:
            val_loss = validate().item()
            tb_writer.add_scalar("loss/val", val_loss, global_step=i)

        x, y = batch
        x = x.cuda()
        y = y.cuda()

        if i == 0:
            # add_images expects float pixels in [0, 1]; undo the ImageNet Normalize
            # applied by the input pipeline (display only).
            mean = x.new_tensor([0.485, 0.456, 0.406]).view(1, -1, 1, 1)
            std = x.new_tensor([0.229, 0.224, 0.225]).view(1, -1, 1, 1)
            tb_writer.add_images("batch/train", (x * std + mean).clamp(0, 1))

        y_hat = model(x)
        batch_loss = loss(y_hat, y)

        optimizer.zero_grad()
        batch_loss.backward()
        optimizer.step()
        num_steps = i + 1

        # .item() synchronises with the GPU; do it once per step and reuse the value.
        train_loss = batch_loss.item()
        pbar.set_postfix(loss=train_loss, val_loss=val_loss)
        tb_writer.add_scalar("loss/train", train_loss, global_step=i)


# Final validation after the last update. It is logged as a float, like the in-loop
# calls, at `num_steps`, the number of updates actually taken. Reusing the last batch
# index would be off by one, could collide with an in-loop point, and is undefined
# for an empty loader.
val_loss = validate().item()
tb_writer.add_scalar("loss/val", val_loss, global_step=num_steps)
tb_writer.flush()