<a href="https://colab.research.google.com/github/abbyambita/Diagnosing-COVID-from-CT-Scan-Images/blob/main/trial_saganv2_kaggle.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

https://github.com/rosinality/sagan-pytorch/

In [1]:
from google.colab import drive
drive.mount('/content/gdrive', force_remount=True)

Mounted at /content/gdrive


In [2]:
import os 

os.chdir("/content/gdrive/My Drive")

!ls  '/content/gdrive/My Drive/CS 284 Mini-Project/Code/output_result/sagan/kaggle'

%cd "/content/gdrive/My Drive/CS 284 Mini-Project/Code/"

'batch_size=64,epoch=2000'  'batch_size=64,epoch=3500'
/content/gdrive/.shortcut-targets-by-id/1eVFVz23F6ROX0s10Oe3tT9HVzr502iW2/CS 284 Mini-Project/Code


In [3]:
from __future__ import print_function
#%matplotlib inline
import argparse
import os
import PIL
import glob
import xml.etree.ElementTree as ET
import random
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import seaborn as sns
from IPython.display import HTML
from torchvision.utils import save_image
from torch.optim.lr_scheduler import StepLR, ReduceLROnPlateau, CosineAnnealingLR
from tqdm import tqdm_notebook as tqdm
from IPython.display import clear_output
from scipy.stats import truncnorm
%matplotlib inline
plt.rcParams['image.interpolation'] = 'nearest'

import torch

from torch import nn
from torch.nn import init
from torch.nn import functional as F

import functools
from torch.autograd import Variable

In [4]:
def init_linear(linear):
    init.xavier_uniform_(linear.weight)
    linear.bias.data.zero_()


def init_conv(conv, glu=True):
    init.xavier_uniform_(conv.weight)
    if conv.bias is not None:
        conv.bias.data.zero_()


class SpectralNorm:
    def __init__(self, name):
        self.name = name

    def compute_weight(self, module):
        weight = getattr(module, self.name + '_orig')
        u = getattr(module, self.name + '_u')
        size = weight.size()
        weight_mat = weight.contiguous().view(size[0], -1)
        with torch.no_grad():
            v = weight_mat.t() @ u
            v = v / v.norm()
            u = weight_mat @ v
            u = u / u.norm()
        sigma = u @ weight_mat @ v
        weight_sn = weight / sigma
        # weight_sn = weight_sn.view(*size)

        return weight_sn, u

    @staticmethod
    def apply(module, name):
        fn = SpectralNorm(name)

        weight = getattr(module, name)
        del module._parameters[name]
        module.register_parameter(name + '_orig', weight)
        input_size = weight.size(0)
        u = weight.new_empty(input_size).normal_()
        module.register_buffer(name, weight)
        module.register_buffer(name + '_u', u)

        module.register_forward_pre_hook(fn)

        return fn

    def __call__(self, module, input):
        weight_sn, u = self.compute_weight(module)
        setattr(module, self.name, weight_sn)
        setattr(module, self.name + '_u', u)


def spectral_norm(module, name='weight'):
    SpectralNorm.apply(module, name)

    return module


def spectral_init(module, gain=1):
    init.kaiming_uniform_(module.weight, gain)
    if module.bias is not None:
        module.bias.data.zero_()

    return spectral_norm(module)


def leaky_relu(input):
    return F.leaky_relu(input, negative_slope=0.2)


class SelfAttention(nn.Module):
    def __init__(self, in_channel, gain=1):
        super().__init__()

        self.query = spectral_init(nn.Conv1d(in_channel, in_channel // 8, 1),
                                   gain=gain)
        self.key = spectral_init(nn.Conv1d(in_channel, in_channel // 8, 1),
                                 gain=gain)
        self.value = spectral_init(nn.Conv1d(in_channel, in_channel, 1),
                                   gain=gain)

        self.gamma = nn.Parameter(torch.tensor(0.0))

    def forward(self, input):
        shape = input.shape
        flatten = input.view(shape[0], shape[1], -1)
        query = self.query(flatten).permute(0, 2, 1)
        key = self.key(flatten)
        value = self.value(flatten)
        query_key = torch.bmm(query, key)
        attn = F.softmax(query_key, 1)
        attn = torch.bmm(value, attn)
        attn = attn.view(*shape)
        out = self.gamma * attn + input

        return out


class ConditionalNorm(nn.Module):
    def __init__(self, in_channel, n_class):
        super().__init__()

        self.bn = nn.BatchNorm2d(in_channel, affine=False)
        self.embed = nn.Embedding(n_class, in_channel * 2)
        self.embed.weight.data[:, :in_channel] = 1
        self.embed.weight.data[:, in_channel:] = 0

    def forward(self, input, class_id):
        out = self.bn(input)
        embed = self.embed(class_id)
        gamma, beta = embed.chunk(2, 1)
        gamma = gamma.unsqueeze(2).unsqueeze(3)
        beta = beta.unsqueeze(2).unsqueeze(3)
        out = gamma * out + beta

        return out


class ConvBlock(nn.Module):
    def __init__(self, in_channel, out_channel, kernel_size=[3, 3],
                 padding=1, stride=1, n_class=None, bn=True,
                 activation=F.relu, upsample=True, self_attention=False):
        super().__init__()

        self.conv = spectral_init(nn.Conv2d(in_channel, out_channel,
                                            kernel_size, stride, padding,
                                            bias=False if bn else True))

        self.upsample = upsample
        self.activation = activation
        self.bn = bn
        if bn:
            self.norm = ConditionalNorm(out_channel, n_class)

        self.self_attention = self_attention
        if self_attention:
            self.attention = SelfAttention(out_channel, 1)

    def forward(self, input, class_id=None):
        out = input
        if self.upsample:
            out = F.upsample(out, scale_factor=2)

        out = self.conv(out)

        if self.bn:
            out = self.norm(out, class_id)

        if self.activation is not None:
            out = self.activation(out)

        if self.self_attention:
            out = self.attention(out)

        return out


class Generator(nn.Module):
    def __init__(self, code_dim=100, n_class=2):
        super().__init__()

        self.lin_code = spectral_init(nn.Linear(code_dim, 4 * 4 * 512))
        self.conv = nn.ModuleList([ConvBlock(512, 512, n_class=n_class),
                                   ConvBlock(512, 512, n_class=n_class),
                                   ConvBlock(512, 512, n_class=n_class,
                                             self_attention=True),
                                   ConvBlock(512, 256, n_class=n_class),
                                   ConvBlock(256, 128, n_class=n_class)])

        self.colorize = spectral_init(nn.Conv2d(128, 3, [3, 3], padding=1))

    def forward(self, input, class_id):
        out = self.lin_code(input)
        out = F.relu(out)
        out = out.view(-1, 512, 4, 4)

        for conv in self.conv:
            out = conv(out, class_id)

        out = self.colorize(out)

        return F.tanh(out)


class Discriminator(nn.Module):
    def __init__(self, n_class=2):
        super().__init__()

        def conv(in_channel, out_channel, stride=2,
                 self_attention=False):
            return ConvBlock(in_channel, out_channel, stride=stride,
                             bn=False, activation=leaky_relu,
                             upsample=False, self_attention=self_attention)

        self.conv = nn.Sequential(conv(3, 128),
                                  conv(128, 256),
                                  conv(256, 512, stride=1,
                                       self_attention=True),
                                  conv(512, 512),
                                  conv(512, 512),
                                  conv(512, 512))

        self.linear = spectral_init(nn.Linear(512, 1))

        self.embed = nn.Embedding(n_class, 512)
        self.embed.weight.data.uniform_(-0.1, 0.1)
        self.embed = spectral_norm(self.embed)

    def forward(self, input, class_id):
        out = self.conv(input)
        out = out.view(out.size(0), out.size(1), -1)
        out = out.sum(2)
        out_linear = self.linear(out).squeeze(1)
        embed = self.embed(class_id)
        prod = (out * embed).sum(1)

        return out_linear + prod

In [5]:
from tqdm import tqdm
import numpy as np
import glob
import os
from PIL import Image

import argparse

import torch
from torch import nn, optim
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms, utils

In [6]:
batch=64
steps = 100
code=128
lr_g=1e-4
lr_d=4e-4
n_d=1
model='dcgan'
path='revised-kaggle-validation/train'
n_class = 2
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

transform = transforms.Compose(
    [
        transforms.Resize(128),
        transforms.CenterCrop(128),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]
)

def requires_grad(model, flag=True):
    for p in model.parameters():
        p.requires_grad = flag


def sample_data(path, batch_size):
    
    #dataroot = "revised-git/train"

    dataset = datasets.ImageFolder(path, transform=transform)
    loader = DataLoader(dataset, shuffle=True, batch_size=batch_size, num_workers=4)
    
    return loader
    #loader = iter(loader)

    # while True:
    #     try:
    #         yield next(loader)

    #     except StopIteration:
    #         loader = DataLoader(
    #             dataset, shuffle=True, batch_size=batch_size, num_workers=4
    #         )
    #         loader = iter(loader)
    #         yield next(loader)


In [7]:
num_epochs = 2000

dir = "output_result/sagan/kaggle/batch_size=64,epoch="+str(num_epochs)
os.makedirs(dir, exist_ok=True)


def cuda(data):
    if torch.cuda.is_available():
        return data.cuda()
    else:
        return data

fixed_z = cuda(torch.randn(64, 100))

def denorm(x):
    out = (x + 1) / 2
    return out.clamp_(0, 1)

In [8]:
def train(n_class, generator, discriminator):
    dataset = sample_data(path,batch)
    #pbar = tqdm(range(iter), dynamic_ncols=True)

    requires_grad(generator, False)
    requires_grad(discriminator, True)

    preset_code = torch.randn(n_class * 5, code).to(device)

    disc_loss_val = 0
    gen_loss_val = 0

    for epoch in range(num_epochs):
      for i, data in enumerate(dataset):
        discriminator.zero_grad()
        #real_image, label = next(dataset)
        real_image = data[0]
        label = data[1]
        real_image = real_image.to(device)
        label = label.to(device)

        b_size = real_image.size(0)
        

        fake_image = generator(
            torch.randn(b_size, code).to(device), label.to(device)
        )
        
        fake_predict = discriminator(fake_image, label)
        real_predict = discriminator(real_image, label)
        loss = F.relu(1 + fake_predict).mean()

        loss = loss + F.relu(1 - real_predict).mean()
        disc_loss_val = loss.detach().item()
        loss.backward()
        d_optimizer.step()

        generator.zero_grad()
        requires_grad(generator, True)
        requires_grad(discriminator, False)
        input_class = torch.multinomial(
            torch.ones(n_class), batch, replacement=True
        ).to(device)
        fake_image = generator(
            torch.randn(batch, code).to(device), input_class
        )
        predict = discriminator(fake_image, input_class)
        loss = -predict.mean()
        gen_loss_val = loss.detach().item()
        loss.backward()
        g_optimizer.step()
        requires_grad(generator, False)
        requires_grad(discriminator, True)

        # if (epoch + 1) % 2 == 0:
        #     generator.train(False)
        #     input_class = torch.arange(n_class).long().repeat(5).to(device)
        #     fake_image = generator(preset_code, input_class)
        #     generator.train(True)
        #     utils.save_image(
        #         fake_image.cpu().data,
        #         #f'sample/{str(i + 1).zfill(7)}.png',
        #         f'output_result/sagan/batch_size=64,epoch=1000/{str(i + 1).zfill(7)}.png',
        #         nrow=n_class,
        #         normalize=True,
        #         range=(-1, 1),
        #     )

        print("Epoch "+str(epoch)+"Dis "+str(disc_loss_val))

        if (epoch + 1) % (50) == 0:
            input_class = torch.arange(n_class).long().repeat(5).to(device)
            fake_image = generator(preset_code, input_class)
            utils.save_image(
                fake_image.cpu().data,
                #f'sample/{str(i + 1).zfill(7)}.png',
                f'output_result/sagan/kaggle/batch_size=64,epoch=2000/{str(epoch + 1).zfill(7)}.png',
                nrow=n_class,
                normalize=True,
                range=(-1, 1),
            )
        
        if (epoch + 1) % (100) == 0:
            torch.save(generator, dir+"/generator_epoch_"+str(epoch)+".pth")
            torch.save(discriminator, dir+"/discriminator_epoch_"+str(epoch)+".pth")

        # if (epoch + 1) % (500) == 0:
        #     generate_images(epoch, 2000)
        

        # if (i + 1) % 10000 == 0:
        #     no = str(i + 1).zfill(7)
        #     torch.save(generator.state_dict(), f'checkpoint/generator_{no}.pt')
        #     torch.save(discriminator.state_dict(), f'checkpoint/discriminator_{no}.pt')
        #     torch.save(g_optimizer.state_dict(), f'checkpoint/gen_optimizer_{no}.pt')
        #     torch.save(d_optimizer.state_dict(), f'checkpoint/dis_optimizer_{no}.pt')

        # pbar.set_description(
        #     (f'{i + 1}; G: {gen_loss_val:.5f};' f' D: {disc_loss_val:.5f}')
        # )

In [None]:
if __name__ == '__main__':

    n_class = len(glob.glob(os.path.join(path, '*/')))
    print(n_class)

    # if model == 'dcgan':
    #     from model import Generator, Discriminator

    # elif model == 'resnet':
    #     from model_resnet import Generator, Discriminator

    generator = Generator(code, n_class).to(device)
    discriminator = Discriminator(n_class).to(device)

    g_optimizer = optim.Adam(generator.parameters(), lr=lr_g, betas=(0, 0.9))
    d_optimizer = optim.Adam(discriminator.parameters(), lr=lr_d, betas=(0, 0.9))
    train(n_class, generator, discriminator)

2




[1;30;43mStreaming output truncated to the last 5000 lines.[0m
Epoch 234Dis 0.9830120801925659
Epoch 234Dis 1.171246886253357
Epoch 234Dis 1.4417893886566162
Epoch 234Dis 1.6567152738571167
Epoch 234Dis 0.9655657410621643
Epoch 234Dis 1.327530860900879
Epoch 234Dis 0.9653356075286865
Epoch 234Dis 1.2037211656570435
Epoch 234Dis 1.6137945652008057
Epoch 234Dis 1.1045794486999512
Epoch 234Dis 0.7062118649482727
Epoch 234Dis 0.8350886106491089
Epoch 234Dis 1.6926193237304688
Epoch 234Dis 1.2528218030929565
Epoch 234Dis 1.6323498487472534
Epoch 234Dis 1.4344942569732666
Epoch 234Dis 1.6434050798416138
Epoch 235Dis 1.4453295469284058
Epoch 235Dis 1.529765248298645
Epoch 235Dis 1.0899845361709595
Epoch 235Dis 1.1845078468322754
Epoch 235Dis 1.1204642057418823
Epoch 235Dis 1.2729911804199219
Epoch 235Dis 1.6143794059753418
Epoch 235Dis 0.9258060455322266
Epoch 235Dis 1.232898235321045
Epoch 235Dis 0.9951980710029602
Epoch 235Dis 1.1392288208007812
Epoch 235Dis 1.622911810874939
Epoch 235Dis

In [9]:
generator = torch.load("output_result/sagan/kaggle/batch_size=64,epoch=2000/generator_epoch_299.pth")

In [12]:
def generate_images(epoch, batch_size):

  for j in range(20):
    input_class = torch.arange(n_class).long().repeat(batch_size).to(device)
    print(len(input_class))

    preset_code = torch.randn(n_class * batch_size, code).to(device)
    print(len(preset_code))
    
    fake_image = generator(preset_code, input_class)
    print(len(fake_image))

    ncv = "sagan_output_images/kaggle/batch_size=64,epoch="+str(epoch)+"/noncovid"
    cv = "sagan_output_images/kaggle/batch_size=64,epoch="+str(epoch)+"/covid"

    os.makedirs(ncv, exist_ok=True)
    os.makedirs(cv, exist_ok=True)

    for i, img in enumerate(fake_image):
      if(input_class[i]==0):
        utils.save_image(
                  img.cpu().data,
                  #f'sample/{str(i + 1).zfill(7)}.png',
                  ncv+'/b'+str(j)+'_fake_img_'+str(i)+'.png'
              )
      else:
        utils.save_image(
                  img.cpu().data,
                  #f'sample/{str(i + 1).zfill(7)}.png',
                  cv+'/b'+str(j)+'_fake_img_'+str(i)+'.png'
              )

generate_images(300, 100)

200
200
200




200
200
200
200
200
200
200
200
200
200
200
200
200
200
200
200
200
200
200
200
200
200
200
200
200
200
200
200
200
200
200
200
200
200
200
200
200
200
200
200
200
200
200
200
200
200
200
200
200
200
200
200
200
200
200
200
200
