In [None]:
import os

import torch
import torchvision
from torch.utils.data import DataLoader
from torchsummary import summary

import pickle
from Split_data import random_split

from WGAN_GP import Generator, Discriminator
from Train_WGAN_GP import train_WGANGP

random_seed = 42  # global seed for torch's default RNG (reproducible runs)
torch.manual_seed(random_seed);  # trailing `;` hides the returned torch.Generator in notebook output

In [None]:
# Parameters
INPUT_LATENT = 128  # length of the latent vector fed to the generator
batch_size = 128    # training batch size
# os.cpu_count() returns None when the core count cannot be determined;
# fall back to 1 so arithmetic on N_CORES (e.g. DataLoader worker counts) still works.
N_CORES = os.cpu_count() or 1

In [None]:
# Load the pre-processed dataset and wrap its splits in DataLoaders.
# SECURITY: pickle.load can execute arbitrary code -- only load a file you trust.
data_file_path = os.path.join("./data", "stop_speed.pkl")

with open(data_file_path, "rb") as data_file:
    reduced_data = pickle.load(data_file)

# Project helper; its three outputs are unpacked as (train, val, test).
# test_ds is not passed to training or validation below.
train_ds, val_ds, test_ds = random_split(reduced_data)

# Use half the available cores as loader workers. `N_CORES or 0` guards against
# os.cpu_count() having returned None (0 workers = load in the main process).
num_workers = (N_CORES or 0) // 2

train_loader = DataLoader(
    train_ds,
    batch_size,
    shuffle=True,
    num_workers=num_workers,
    pin_memory=True,
)

# Validation must draw from the validation split, not the held-out test split.
# No backward pass is needed for validation, so a larger batch usually fits.
val_loader = DataLoader(
    val_ds,
    batch_size * 2,
    num_workers=num_workers,
    pin_memory=True,
)

In [None]:
# Compute devices for the two networks. Both target the default CUDA device;
# keeping separate handles allows placing D and G on different GPUs later.
device_D, device_G = torch.device('cuda'), torch.device('cuda')

In [None]:
# Instantiate the generator and print a layer-by-layer summary.
# torchsummary adds the batch dimension itself, so input_size describes a single
# latent vector shaped (INPUT_LATENT, 1, 1). The model has not been moved to its
# device yet (that happens in a later cell), hence device='cpu'.
netG = Generator()
summary(
    netG,
    input_size=(INPUT_LATENT, 1, 1),
    device='cpu',
)

In [None]:
# Instantiate the discriminator and print a layer-by-layer summary for a single
# 3x32x32 input (torchsummary prepends the batch dimension). Still on the CPU here.
# NOTE(review): (3, 32, 32) must match the image shape in the dataset -- confirm.
netD = Discriminator()
summary(
    netD,
    input_size=(3, 32, 32),
    device='cpu',
)

In [None]:
# Create output folders for model checkpoints and generated images.
# os.makedirs also creates missing parents (os.mkdir would raise if
# ./trained_models did not exist yet), and exist_ok=True makes the cell re-runnable.
model_folder = os.path.abspath('./trained_models/WGAN_GP')
os.makedirs(model_folder, exist_ok=True)

img_folder = os.path.abspath('./Generated_imgs')
os.makedirs(img_folder, exist_ok=True)

In [None]:
# Resume from the last saved snapshot, if one exists.
# Default to epoch 0 so `inital_epoch` is always defined, and so a stale value
# from an earlier run of this cell is cleared when no snapshot is found.
check_point_path = './trained_models/WGAN_GP/model_snapshots.pth'
inital_epoch = 0

if os.path.exists(check_point_path):
    # The networks are still on the CPU at this point in the notebook, so map the
    # stored tensors there; this also lets a GPU-saved snapshot load on any machine.
    # SECURITY: torch.load unpickles the file -- only load trusted checkpoints.
    checkpoint = torch.load(check_point_path, map_location='cpu')

    # NOTE(review): confirm whether the saved 'epoch' is the last completed epoch
    # or the next one to run, to avoid repeating an epoch on resume.
    inital_epoch = checkpoint['epoch']

    netG.load_state_dict(checkpoint['netG_state_dict'])
    netD.load_state_dict(checkpoint['netD_state_dict'])

In [None]:
# Place each network on its compute device (generator first, then discriminator).
netG, netD = netG.to(device_G), netD.to(device_D)

In [None]:
# Train WGAN-GP.
# Keep the epoch restored by the checkpoint-loading cell when it defined
# `inital_epoch`; otherwise start from epoch 0. Unconditionally resetting it
# to 0 here would discard a resumed run's progress.
inital_epoch = globals().get('inital_epoch', 0)

train_WGANGP(train_loader, val_loader, netD, netG, inital_epoch)