In [2]:
import data
import torch
import torch.nn as nn

In [4]:
%cd

/home/kacper


In [161]:
class Generator(nn.Module):
    def __init__(self, latent_size, image_size=64):
        super(Generator, self).__init__()
        self.latent_size = latent_size
        
        self.mlp = nn.Sequential(
            nn.Linear(latent_size, latent_size, bias=False),
            nn.ReLU(),
            nn.Linear(latent_size, latent_size, bias=False),
            nn.ReLU(),
        )
                
        self.conv = nn.Sequential(
            nn.ConvTranspose2d(
                in_channels=latent_size, 
                out_channels=latent_size//2, 
                kernel_size=4, 
                stride=4, 
                padding=0, 
                bias=False),
            nn.ReLU(True),
            nn.ConvTranspose2d(
                in_channels=latent_size//2, 
                out_channels=latent_size//4, 
                kernel_size=4, 
                stride=4, 
                padding=0, 
                bias=False),
            nn.ReLU(True),
            nn.ConvTranspose2d(
                in_channels=latent_size//4, 
                out_channels=3, 
                kernel_size=4, 
                stride=4, 
                padding=0, 
                bias=False),
            nn.Tanh()
        )

    def forward(self, x):
        x = self.mlp(x)
        x = x.reshape((-1, self.latent_size, 1, 1))
        x = self.conv(x)
        return x
    
class Discriminator(nn.Module):
    def __init__(self, latent_size, image_size=64):
        super(Discriminator, self).__init__()
        self.latent_size = latent_size
        
        self.mlp = nn.Sequential(
            nn.Linear(latent_size, latent_size, bias=False),
            nn.ReLU(),
            nn.Linear(latent_size, 1, bias=False),
            nn.Sigmoid(),
        )
                
        self.conv = nn.Sequential(
            nn.Conv2d(
                in_channels=3, 
                out_channels=latent_size//4, 
                kernel_size=4, 
                stride=4, 
                padding=0, 
                bias=False),      
            nn.ReLU(True),
            nn.Conv2d(
                in_channels=latent_size//4, 
                out_channels=latent_size//2, 
                kernel_size=4, 
                stride=4, 
                padding=0, 
                bias=False),
            nn.ReLU(True),
            nn.Conv2d(
                in_channels=latent_size//2, 
                out_channels=latent_size, 
                kernel_size=4, 
                stride=4, 
                padding=0, 
                bias=False),
            nn.ReLU(True),
        )

    def forward(self, x):
        x = self.conv(x)
        x = x.reshape(-1, self.latent_size)
        x = self.mlp(x)        
        return x.reshape(-1)

In [162]:
def backup_to_ram(model):
    from copy import deepcopy
    return deepcopy(model).cpu()

class EarlyStopperGAN:
    def __init__(self, patience = 3, backup_method=backup_to_ram):
        self.patience = patience
        self.current = 0
        
        self.backup_method = backup_method
        
        self.best_backup = None
        self.best_loss_g = float("inf")
        self.best_loss_d = float("inf")

    def should_continue(self, loss_g, loss_d):
        if self.best_loss_g + 1e-6 > loss_g or self.best_loss_d + 1e-6 > loss_d:
            self.current = 0
            self.best_loss_g = loss_g
            self.best_loss_d = loss_d
            return True
        
        self.current += 1
        
        if self.current >= self.patience:
            return False
        return True
    
class EarlyStop(Exception):
    pass

In [163]:
from DeepLearning.Project3.data import load_dataloader_preprocess
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [166]:
def train_run(run):
    path = f"experiments_gan/simpleconvgan_{run}_"
    import os
    try:
        if os.stat(path + "report.json").st_size != 0:
            print("Report exists already for " + path[:-1] + ". Skipping...")
            return
    except OSError:
        pass
    
    
    
    try:
        latent_size = 128
        batch_size = 256

        early_stopper = EarlyStopperGAN(20)

        generator = Generator(latent_size).to(device)
        discriminator = Discriminator(latent_size).to(device)

        generator_optimizer = torch.optim.Adam(generator.parameters(), lr=0.001)
        discriminator_optimizer = torch.optim.Adam(discriminator.parameters(), lr=0.001)

        loss = nn.BCELoss()

        epochs = 10

        trajectory = []
        i = 0

        for epoch in range(epochs):
            for true_data, _ in load_dataloader_preprocess(bs = batch_size):
                true_data = true_data.to(device)
                cur_bs = true_data.shape[0]
                label = torch.zeros(cur_bs, device=device)

                # Train D on true data
                discriminator_optimizer.zero_grad()
                label.fill_(1.) # true data
                true_discriminator_out = discriminator(true_data)
                true_discriminator_loss = loss(true_discriminator_out, label)
                true_discriminator_loss.backward()
                true_accuracy = true_discriminator_out.mean().item()

                # Train D on fake data
                label.fill_(0.) # fake data

                ## Generate the fake data
                noise = torch.randn((cur_bs, latent_size), device=device)
                fake_data = generator(noise)

                ## Pass to D
                fake_discriminator_out = discriminator(fake_data)
                fake_discriminator_loss = loss(fake_discriminator_out, label)
                fake_discriminator_loss.backward(retain_graph=True)

                discriminator_total_loss = true_discriminator_loss.item() + fake_discriminator_loss.item()

                discriminator_optimizer.step()

                # Train G
                generator_optimizer.zero_grad()

                label.fill_(1.)  # fake labels are real for generator cost

                output = discriminator(fake_data).view(-1)
                generator_loss = loss(output, label)
                generator_loss.backward()
                generator_optimizer.step()
                
                i +=1
                if not early_stopper.should_continue(generator_loss.item(), discriminator_total_loss):
                    raise Exception()
                if i % 50 == 0:
                    print(generator_loss.item(), discriminator_total_loss)
                trajectory.append({
                    "epoch": epoch,
                    "iteration": i,
                    "d_loss": discriminator_total_loss, 
                    "g_loss": generator_loss.item()
                })
    except EarlyStop:
        print("early stop")
        pass
    
    

    import json
    with open(f"{path}report.json", "w") as report_file:
        json.dump(
            {
                "experiment_name": "simple_conv_gan",
                "run": run,
                "trajectory": trajectory,
                "best_loss_g": early_stopper.best_loss_g, 
                "best_loss_d": early_stopper.best_loss_d
            },
            report_file
        )

    torch.save(generator, path + "generator.pt")
    torch.save(discriminator, path + "discriminator.pt")

In [167]:
for run in range(1, 21):
    train_run(run)

Report exists already for experiments_gan/simpleconvgan_1. Skipping...
Report exists already for experiments_gan/simpleconvgan_2. Skipping...
1.8186657428741455 0.254852849728195
0.6931419968605042 0.6931603874154462
0.6898683309555054 0.6967872690293007
0.6978483200073242 0.6888187292552175
0.6933448314666748 0.6929519432597413
0.6934757232666016 0.6928252093494027
0.6932801604270935 0.6930148974061936
0.693203866481781 0.6930907964706421
0.6932120323181152 0.6931180277424573
0.6931849122047424 0.6931096315383911
0.6931692361831665 0.6931251287460327
0.6931977272033691 0.6930967569351196
0.693156898021698 0.6931374673731625
0.6931583881378174 0.69313911451286
0.6931576132774353 0.6931368121877314
0.6931557059288025 0.6931386590003967
0.6931471824645996 0.6931471824645996
0.693163275718689 0.6931311958472293
0.6931504011154175 0.6931439638137817
0.6931486129760742 0.693145751953125
0.6931627988815308 0.693131685256958
0.6931545734405518 0.6931397919543087
0.6931548118591309 0.693139553

0.6931517124176025 0.6931426525115967
0.6931471824645996 0.6931471824645996
0.6803456544876099 0.7065075635910034
0.7025401592254639 0.6915868880771541
0.8201166391372681 0.6208558231592178
0.7236900329589844 0.6822402407124173
0.9210761785507202 0.5902311913669109
0.6921043992042542 0.6942351171753671
0.9107023477554321 0.6404752172529697
0.7287144660949707 0.6874364484101534
0.7235857844352722 0.703990469686687
0.7862896919250488 0.6342044770717621
1.0572749376296997 0.6317061707377434
0.7764925956726074 0.6602591242990457
0.6878597736358643 0.7246655244380236
0.6651313900947571 0.7305306607449893
0.7663085460662842 0.6447775245178491
0.776637077331543 0.7202689200639725
0.6669250726699829 0.7496459260582924
0.7297681570053101 0.6640792330726981
0.6658740639686584 0.7216559167991363
1.085357666015625 0.5958950594067574
0.6991008520126343 0.8471646346151829
0.7774752974510193 0.6538216713815928
0.7527468204498291 0.6507613328776642
0.69398033618927 0.6923527940507483
0.693344235420227

0.6931471824645996 0.6931471824645996
0.6931471824645996 0.6931471824645996
0.6931471824645996 0.6931471829302609
0.6931471824645996 0.6931471824645996
0.6931471824645996 0.6931471824645996
0.6931471824645996 0.6931471824645996
0.6931471824645996 0.6931471824645996
0.6931471824645996 0.6931471824645996
0.6931471824645996 0.6931471824645996
0.6931471824645996 0.6931471824645996
0.6931471824645996 0.6931471824645996
0.6931471824645996 0.6931471866555534
0.6931471824645996 0.6931471824645996
0.6931471824645996 0.6931471824645996
0.6931471824645996 0.6931471824645996
0.6931471824645996 0.6931471824645996
0.6931471824645996 0.6931471824645996
0.6931471824645996 0.6931471824645996
0.6931471824645996 0.6931471824645996
0.6931471824645996 0.6931471824645996
0.6931471824645996 0.6931471824645996
0.6931471824645996 0.6931471824645996
0.6931471824645996 0.6931471824645996
0.6931471824645996 0.6931471824645996
0.6931471824645996 0.6931471824645996
0.6931471824645996 0.6931471824645996
0.6931471824

0.6931471824645996 0.6931471824645996
0.6931471824645996 0.6931471824645996
0.6931509971618652 0.693143367767334
0.6931471824645996 0.6931471824645996
0.6931471824645996 0.6931471824645996
0.6931471824645996 0.6931471824645996
0.6696639060974121 0.7180449962615967
0.6934850215911865 0.6928448677062988
0.6931471824645996 0.6931471824645996
0.6931471824645996 0.6931471824645996
0.6931471824645996 0.6931471824645996
0.6931471824645996 0.6931471824645996
0.6931471824645996 0.6931471824645996
0.6931471824645996 0.6931471824645996
0.6931471824645996 0.6931471824645996
0.6931471824645996 0.6931471824645996
0.6931471824645996 0.6931471824645996
0.6931471824645996 0.6931471824645996
0.6931471824645996 0.6931471824645996
0.6931471824645996 0.6931471824645996
0.6931471824645996 0.6931471824645996
0.6931471824645996 0.6931471824645996
0.6931471824645996 0.6931471824645996
0.6931471824645996 0.6931471824645996
0.6931471824645996 0.6931471824645996
0.6931471824645996 0.6931471824645996
0.69918215274

0.6931471824645996 0.6931471852585684
0.6931471824645996 0.6931471880525364
0.6931471824645996 0.6931471829302609
0.6931471824645996 0.6931471941061362
0.6931471824645996 0.6931471866555521
0.6931471824645996 0.6931473175080924
0.6931471824645996 0.6931471852585678
0.6931471824645996 0.6931471861898908
0.6931471824645996 0.693147222511648
0.6931471824645996 0.6931471824645996
0.6931471824645996 0.6931471833959223
0.6931471824645996 0.693147200159741
0.6931471824645996 0.6931471996940868
0.6931471824645996 0.6931471996940921
0.6931471824645996 0.6931471861898908
0.6931471824645996 0.6931471917778325
0.6931471824645996 0.6931472118013176
0.6931471824645996 0.693147191312173
0.6931471824645996 0.6931472020224057
0.6931471824645996 0.6931472034193824
0.6931471824645996 0.6931471857242291
0.6931471824645996 0.6931471857242293
0.6931471824645996 0.6931471847929065
0.6931471824645996 0.6931472430007517
0.6931471824645996 0.6931471852585678
0.6931471824645996 0.6931480354096493
0.6931471824645

0.6931471824645996 0.6931471824645996
0.6931453943252563 0.6931489706039429
0.6931471824645996 0.6931471824645996
0.6931471824645996 0.6931471824645996
0.6931471824645996 0.6931471824645996
0.6931305527687073 0.6931638717651367
0.7259037494659424 0.6890489384531975
0.6940861940383911 0.692304866441134
0.6931622624397278 0.6931322813034058
0.6931471824645996 0.6931471824645996
0.6931471824645996 0.6931471824645996
0.6931471824645996 0.6931471824645996
0.6931474208831787 0.6931469440460205
0.6940258741378784 0.6924439668655396
0.6931471824645996 0.6931471829302609
0.6931473016738892 0.6931470632553101
0.6934535503387451 0.6928691864013672
0.6931471824645996 0.6931471824645996
0.6931471824645996 0.6931471824645996
0.6931586265563965 0.6931357979774475
0.6931471824645996 0.6931471824645996
0.6931471824645996 0.6931471824645996
0.6931471824645996 0.6931471824645996
0.6931471824645996 0.6931471824645996
0.6931471824645996 0.6931471824645996
0.6931471824645996 0.6931471824645996
0.69314718246

0.6931471824645996 0.6931471824645996
0.6931471824645996 0.6931471824645996
0.6931471824645996 0.6931471824645996
0.6931471824645996 0.6931471829302609
0.6931471824645996 0.6931471857242297
0.6931471824645996 0.6931471838615837
0.6931471824645996 0.6931471852585684
0.6931471824645996 0.6931471824645996
0.6931471824645996 0.6931471950374748
0.6931471824645996 0.6931471824645996
0.6931471824645996 0.6931471829302609
0.6931471824645996 0.6931471824645996
0.6931471824645996 0.6931471824645996
0.6931474208831787 0.6931470036506653
0.6931471824645996 0.69314718572423
0.6931471824645996 0.6931471829302609
0.6931471824645996 0.6931472006254271
0.6931471824645996 0.6931471824645996
0.6931471824645996 0.6931471824645996
0.6931471824645996 0.6931471824645996
0.6931471824645996 0.693147185258568
0.6931471824645996 0.6931471829302609
0.6931471824645996 0.6931471824645996
0.6931471824645996 0.6931471852585684
0.6931471824645996 0.6931471833959223
0.6931471824645996 0.6931471824645996
0.6931471824645

0.6931471824645996 0.6931471824645996
0.6931471824645996 0.6931471824645996
0.6931471824645996 0.6931471824645996
0.6931471824645996 0.6931471824645996
0.6931471824645996 0.6931471824645996
0.6931471824645996 0.6931471824645996
0.6931471824645996 0.6931471824645996
0.6931471824645996 0.6931471824645996
0.6931471824645996 0.6931471829302609
0.6931471824645996 0.6931471824645996
0.6931471824645996 0.6931471824645996
0.6931471824645996 0.6931471824645996
0.6931471824645996 0.6931471824645996
0.6931471824645996 0.6931471824645996
0.6931471824645996 0.6931471824645996
0.6931471824645996 0.6931471824645996
0.6931471824645996 0.6931471824645996
0.6931471824645996 0.6931471824645996
0.6931471824645996 0.6931471824645996
0.6931471824645996 0.6931471824645996
0.6931471824645996 0.6931471824645996
0.6931471824645996 0.6931471829302609
0.6931471824645996 0.6931471824645996
0.6931471824645996 0.6931471824645996
0.6931471824645996 0.6931471824645996
0.6931471824645996 0.6931471824645996
0.6931471824

0.6931471824645996 0.6931471824645996
0.6931471824645996 0.6931471824645996
0.6931471824645996 0.6931471824645996
0.6931471824645996 0.6931471824645996
0.6931471824645996 0.6931471824645996
0.6931471824645996 0.6931471824645996
0.6931471824645996 0.6931471824645996
0.6931471824645996 0.6931471824645996
0.6931471824645996 0.6931471824645996
0.6931473016738892 0.6931470632553101
0.6931471824645996 0.6931471824645996
0.6931471824645996 0.6931471824645996
0.6931471824645996 0.6931471824645996
0.6931471824645996 0.6931471824645996
0.6931471824645996 0.6931471824645996
0.6931471824645996 0.6931471824645996
0.6931471824645996 0.6931471824645996
0.6931471824645996 0.6931471824645996
0.6931471824645996 0.6931471824645996
0.6931471824645996 0.6931471824645996
0.6931471824645996 0.6931471824645996
0.6931471824645996 0.6931471824645996
0.6931471824645996 0.6931471824645996
0.6931471824645996 0.6931471824645996
0.6931471824645996 0.6931471824645996
0.6931471824645996 0.6931471824645996
0.6931471824

Exception: 