In [None]:
import pickle
import torch
import torch.nn as nn
import logging
import torch.optim as optim
import torch.utils.data
import torchvision.transforms as transforms
import torchvision.utils as vutils
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation

from dcgan_with_embeddings import Generator, Discriminator, PathFoundationModel
from dataset_loader import CustomDataset
from IPython.display import HTML
from torch.utils.data import DataLoader

In [None]:
# --- Training hyperparameters (each one is also written to gan.log below) ---
batch_size = 128   # samples per DataLoader batch
image_size = 224   # target size handed to transforms.Resize
nc = 3             # number of image channels
nz = 100           # size of the latent vector fed to the generator
ngf = 64           # number of generator filters
ndf = 64           # number of discriminator filters
num_epochs = 10    # full passes over the training loader
lr = 0.0002        # Adam learning rate (shared by both optimizers)
beta1 = 0.5        # Adam beta1 (shared by both optimizers)

# --- Data locations, relative to the notebook's working directory ---
DATASET_DIR = "../datasets/merged_embeddings/merged_dataset.pkl"
SLIDE_DIR = "../datasets/wsi"

# Pick the compute device: Apple MPS first, then CUDA, otherwise CPU.
# `torch.backends.mps.is_available()` is the long-standing availability check;
# `torch.mps.is_available()` is not present in every PyTorch release and raises
# AttributeError where it is missing.
if torch.backends.mps.is_available():
    DEVICE = torch.device("mps")
elif torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")
print("Using device: ", DEVICE)

# NOTE(review): this configures the *root* logger at DEBUG, so debug records
# emitted by third-party libraries through `logging` also land in gan.log.
logging.basicConfig(filename='gan.log', level=logging.DEBUG, format='%(asctime)s %(message)s')

# Record the full run configuration so every gan.log is self-describing.
logging.info("Batch size: %d", batch_size)
logging.info("Image size: %d", image_size)
logging.info("Number of channels: %d", nc)
logging.info("Size of latent vector: %d", nz)
logging.info("Number of generator filters: %d", ngf)
logging.info("Number of discriminator filters: %d", ndf)
logging.info("Number of epochs: %d", num_epochs)
logging.info("Learning rate: %f", lr)
logging.info("Beta1: %f", beta1)
logging.info("Dataset directory: %s", DATASET_DIR)
logging.info("Slide directory: %s", SLIDE_DIR)
logging.info("Device: %s", DEVICE)

In [None]:
# NOTE(security): pickle.load can execute arbitrary code stored in the file;
# only load datasets produced by a trusted pipeline.
with open(DATASET_DIR, "rb") as f:
    train_dataset = pickle.load(f)

if len(train_dataset) == 0:
    raise ValueError(f"No records found in {DATASET_DIR}")

# Peek at one record as a sanity check. The index is clamped so that datasets
# with fewer than 101 records no longer raise IndexError here.
idx = min(100, len(train_dataset) - 1)
print("Length of train dataset: ", len(train_dataset))
print("Train dataset keys: ", train_dataset[idx].keys())
print("Slide name: ", train_dataset[idx]["slide_name"])
print("Embedding vector shape: ", train_dataset[idx]["embedding_vector"].shape)

# Image preprocessing: resize, convert to tensor, then normalise each of the
# three channels with mean 0.5 / std 0.5.
# NOTE(review): Resize with a single int scales the shorter edge and keeps the
# aspect ratio — this assumes the source patches are square; TODO confirm.
transform = transforms.Compose([
    transforms.Resize(image_size),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
logging.info("Transforms: %s", transform)

train_data = CustomDataset(train_dataset, slide_dir=SLIDE_DIR, transform=transform)
train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
logging.info("Length of train loader: %d", len(train_loader))

# Preview up to 64 images (element 0 of a batch). Slice first and build the
# grid without shipping the whole batch to DEVICE and back; the trailing
# .cpu() is kept as a no-op safeguard before handing the grid to NumPy.
real_batch = next(iter(train_loader))
preview_grid = vutils.make_grid(real_batch[0][:64], padding=2, normalize=True).cpu()
plt.figure(figsize=(8,8))
plt.axis("off")
plt.title("Training Images")
plt.imshow(np.transpose(preview_grid, (1,2,0)))
plt.show()

In [4]:
# Build the generator from the latent size, generator filter count and channel
# count, then place it on the selected device.
netG = Generator(nz, ngf, nc)
netG = netG.to(DEVICE)
logging.info("Generator: %s", netG)

In [5]:
# Build the discriminator with its default constructor arguments and place it
# on the selected device.
netD = Discriminator()
netD = netD.to(DEVICE)
logging.info("Discriminator: %s", netD)

In [None]:
# Embedding model used in the training loop to turn generated images into
# embeddings for the discriminator.
netGoogle = PathFoundationModel(
    model_name="google/path-foundation",
)
logging.info("PathFoundationModel: %s", netGoogle)

In [7]:
# Binary cross-entropy between discriminator outputs and real/fake targets.
criterion = nn.BCELoss()
logging.info("Criterion: %s", criterion)

# One fixed latent batch of 64 vectors, reused to render generator samples at
# intervals during training.
fixed_noise = torch.randn((64, nz, 1, 1), device=DEVICE)

# Target values used to fill the label tensor.
real_label, fake_label = 1.0, 0.0
logging.info("Real label: %f", real_label)
logging.info("Fake label: %f", fake_label)

# Both networks get their own Adam optimizer with identical settings.
adam_betas = (beta1, 0.999)
optimizerD = optim.Adam(netD.parameters(), lr=lr, betas=adam_betas)
optimizerG = optim.Adam(netG.parameters(), lr=lr, betas=adam_betas)
logging.info("Optimizer D: %s", optimizerD)
logging.info("Optimizer G: %s", optimizerG)

In [None]:
# Adversarial training loop.
# Bookkeeping: sample grids rendered from fixed_noise, per-iteration loss
# histories, and a global step counter that runs across epochs.
img_list = []
G_losses = []
D_losses = []
iters = 0

print("Starting Training Loop...")

for epoch in range(num_epochs):

    for i, data in enumerate(train_loader, 0):
        # ---- (1) Discriminator update: real half ----
        netD.zero_grad()
        # Element 1 of the batch is what netD sees as "real" (element 0 is
        # what the preview cells plot as images).
        # presumably element 1 holds the stored embedding vectors — TODO confirm against CustomDataset
        real_cpu = data[1].to(DEVICE)
        b_size = real_cpu.size(0)
        label = torch.full((b_size,), real_label, dtype=torch.float, device=DEVICE)
        # Flatten netD's output so it lines up with the 1-D label tensor.
        output = netD(real_cpu).view(-1)
        errD_real = criterion(output, label)
        errD_real.backward()
        # Mean discriminator output on the real batch (reported as D(x)).
        D_x = output.mean().item()

        # ---- (1) Discriminator update: fake half ----
        noise = torch.randn(b_size, nz, 1, 1, device=DEVICE)
        fake = netG(noise)
        # The generated batch is detached and converted to NumPy before it is
        # embedded, so fake_embedding carries no autograd link back to netG.
        fake_embedding = netGoogle.inference(fake.cpu().detach().numpy())
        fake_embedding = fake_embedding.to(DEVICE)
        label.fill_(fake_label)
        output = netD(fake_embedding.detach())
        output = output.view(-1)
        
        errD_fake = criterion(output, label)
        # Gradients add onto those from the real half; one step covers both.
        errD_fake.backward()
        # Mean discriminator output on fakes before the D step (first D(G(z)) value).
        D_G_z1 = output.mean().item()
        errD = errD_real + errD_fake
        optimizerD.step()


        # ---- (2) Generator update ----
        # NOTE(review): fake_embedding was computed from a detached NumPy copy
        # of the generator output (see above), so errG.backward() cannot reach
        # netG's parameters and optimizerG.step() has no generator gradients
        # to apply from this loss. As written, the backward pass only
        # accumulates into netD's parameters, which netD.zero_grad() clears at
        # the top of the next iteration. Training the generator this way needs
        # an embedding path that stays inside the autograd graph.
        netG.zero_grad()
        label.fill_(real_label)  # generator loss uses the "real" target
        output = netD(fake_embedding).view(-1)
        errG = criterion(output, label)
        errG.backward()
        # Mean discriminator output on fakes after the D step (second D(G(z)) value).
        D_G_z2 = output.mean().item()
        optimizerG.step()

        # Progress report every 50 batches, to stdout and to the log file.
        if i % 50 == 0:
            print('[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tD(x): %.4f\tD(G(z)): %.4f / %.4f'
                  % (epoch, num_epochs, i, len(train_loader),
                     errD.item(), errG.item(), D_x, D_G_z1, D_G_z2))
            
            logging.info('[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tD(x): %.4f\tD(G(z)): %.4f / %.4f'
                    % (epoch, num_epochs, i, len(train_loader),
                         errD.item(), errG.item(), D_x, D_G_z1, D_G_z2))

        # Keep every iteration's losses for the loss-curve plot.
        G_losses.append(errG.item())
        D_losses.append(errD.item())

        # Every 500 global steps, and on the very last batch of the last
        # epoch, render the generator's output for fixed_noise as an image grid.
        if (iters % 500 == 0) or ((epoch == num_epochs-1) and (i == len(train_loader)-1)):
            with torch.no_grad():
                fake = netG(fixed_noise).detach().cpu()
            img_list.append(vutils.make_grid(fake, padding=2, normalize=True))

        iters += 1

In [None]:
# Loss curves for both networks, one point per training iteration.
loss_fig, loss_ax = plt.subplots(figsize=(10, 5))
loss_ax.set_title("Generator and Discriminator Loss During Training")
loss_ax.plot(G_losses, label="G")
loss_ax.plot(D_losses, label="D")
loss_ax.set_xlabel("iterations")
loss_ax.set_ylabel("Loss")
loss_ax.legend()
plt.show()

In [None]:
# Animate the sample grids collected during training, one grid per frame.
fig = plt.figure(figsize=(8, 8))
plt.axis("off")

ims = []
for grid in img_list:
    # Grids are channel-first; move channels last for imshow.
    frame = plt.imshow(np.transpose(grid, (1, 2, 0)), animated=True)
    ims.append([frame])

ani = animation.ArtistAnimation(fig, ims, interval=1000, repeat_delay=1000, blit=True)
ani.save("animation.html", writer="html")
HTML(ani.to_jshtml())

In [None]:
# Side-by-side comparison: a fresh batch of real images next to the most
# recent grid of generated samples.
real_batch = next(iter(train_loader))

# Slice to at most 64 images and build the grid directly, instead of moving
# the whole batch to DEVICE and straight back just to plot it. The trailing
# .cpu() is kept as a no-op safeguard before handing the grid to NumPy.
real_grid = vutils.make_grid(real_batch[0][:64], padding=5, normalize=True).cpu()

plt.figure(figsize=(15,15))
plt.subplot(1,2,1)
plt.axis("off")
plt.title("Real Images")
plt.imshow(np.transpose(real_grid, (1,2,0)))


plt.subplot(1,2,2)
plt.axis("off")
plt.title("Fake Images")
# Last entry of img_list is the grid rendered on the final training batch.
plt.imshow(np.transpose(img_list[-1],(1,2,0)))
plt.show()

In [12]:
# Persist only the learned parameters (state_dict) of each network, written to
# the current working directory.
generator_state = netG.state_dict()
discriminator_state = netD.state_dict()
torch.save(generator_state, "generator.pth")
torch.save(discriminator_state, "discriminator.pth")