In [None]:
import sys
sys.path.insert(0,'../src/')

import os

from data.make_dataset import *
from visualization.visualize import *
from models.models import *

%load_ext autoreload
%autoreload 1

In [None]:
import torch
import numpy

# Report which compute device is available. The GPU branch reads the
# properties of CUDA device 0 and prints its name and total memory in GB
# (bytes / 1e9, rounded to two decimals).
if torch.cuda.is_available():
    prop = torch.cuda.get_device_properties(0)
    # Use the `numpy` name this cell imports; the alias `np` is not bound by
    # this cell and would only exist if a star import happened to export it.
    print(f"Using {prop.name} with {numpy.round(prop.total_memory/1e9, 2)}GB of RAM")
else:
    print("Using CPU")

# Load Data

In [None]:
# Settings consumed by the cells below.
channel_size = 3  # handed to both DCGAN constructors
image_size, batch_size, workers = 64, 128, 10  # handed to make_dataloader, in this order

In [None]:
%%time
# Build the data loader from the raw image folder using the image size,
# batch size and worker count configured above; %%time reports how long it takes.
data_loader = make_dataloader("../data/raw/planctons_original", image_size, batch_size, workers)

## Visualize Data

In [None]:
# Sanity check: display samples drawn from the loader before training.
print_samples(data_loader)

# Build Models

In [None]:
# Network dimensions: latent_size goes to the generator only, feature_map_size
# to both the generator and the discriminator.
latent_size, feature_map_size = 100, 64

In [None]:
# Optimisation settings; lr and beta are handed to GAN(...) below.
lr, beta = 0.0002, 0.5
# NOTE(review): num_epochs is not referenced by any visible cell — confirm intended use.
num_epochs = 2

In [None]:
# Instantiate the generator and discriminator from the sizes configured above.
# The generator is created first; keep this order if reproducible
# initialisation matters, since construction presumably consumes RNG state.
g = DCGAN_Generator(latent_size, feature_map_size, channel_size)
d = DCGAN_Discriminator(feature_map_size, channel_size)

# Wrap both networks in the GAN helper. Note the argument order:
# discriminator first, then generator, then learning rate and beta.
gan = GAN(d, g, lr, beta)

# Train Model

In [None]:
%%time
# Train the GAN on the loader built above; %%time reports the wall time.
# NOTE(review): the second argument is the literal 1, while `num_epochs = 2`
# is defined earlier and not used in the visible cells — confirm which is intended.
gan.train(data_loader, 1)

## Persist model

In [None]:
# Destination for the saved model: ../models/version_0/gan.pkl.
# Both names are reused by the final cell to load the model back.
outdir = os.path.join('..', 'models', 'version_0')
outfile = 'gan.pkl'

# Save the trained GAN to the directory and file name chosen above.
gan.persist(outdir, outfile)

## Random Search

## Visualize training stats

In [None]:
# Collect the second field of every logged entry for each network.
# Each name now reads from the matching key; previously g_loss was taken from
# 'discriminator' and d_loss from 'generator', which swapped the two curves.
# NOTE(review): assumes index 1 of each entry is the loss value — confirm
# against the logging done in GAN.train.
g_loss = [entry[1] for entry in gan.loss['generator']]
d_loss = [entry[1] for entry in gan.loss['discriminator']]

In [None]:
# Plot the two loss series. The trailing comment holds an optional img_path
# argument, presumably for saving the figure to disk.
plot_loss(d_loss, g_loss)#, img_path=f'../reports/model_analysis/loss_{5}.png')

In [None]:
# Show how the generator's outputs changed over training, using the images
# stored in gan.img_list. The trailing comment holds an optional gif_path argument.
generator_progress(gan.img_list)#, gif_path=f'../reports/model_analysis/progress_{5}.png')

In [None]:
# Compare real images from the loader with generated images from gan.img_list.
# The f-string's {5} is a literal, so img_path ends in comparison_5.png;
# presumably the figure is written there — confirm in compare_fake_real.
compare_fake_real(data_loader, gan.device, gan.img_list, img_path=f'../reports/model_analysis/comparison_{5}.png')

## Use models

In [None]:
# Draw 64 standard-normal latent tensors of shape (latent_size, 1, 1) on the
# GAN's device and pass them through gan.generate.
fake_batch = gan.generate(torch.randn(64, gan.generator.latent_size, 1, 1, device=gan.device))
# Display the generated batch (the helper takes a list of batches).
print_batch_images([fake_batch], gan.device)

In [None]:
# Take element 0 of the first batch from the loader.
# NOTE(review): presumably the image tensor of an (images, labels) pair — confirm.
img_data = next(iter(data_loader))[0]
# Run the discriminator on that batch; the result is shown as the cell output.
gan.predict_discriminator(img_data)

# Load trained GAN

In [None]:
# Load a GAN from the same outdir/outfile that was passed to gan.persist;
# this rebinds `gan`, replacing the in-memory model trained above.
# NOTE(review): the file is named gan.pkl — if GAN.load unpickles it, only load trusted files.
gan = GAN.load(os.path.join(outdir, outfile))