In [None]:
import os

import torch
from torch.utils.data import DataLoader, random_split
import torchvision.transforms as transforms
import torch.nn as nn
import torch.optim as optim

from utils.dataset import CocoDataset
from utils.plots import plot_l, plot_model_pred, plot_losses
from utils.models import EncoderDecoderGenerator, PatchGAN, save_model, load_model
from utils.metrics import evaluate_model
from utils.training import train_gan, load_losses

# Runtime and data configuration for the whole notebook.
device = "cuda" if torch.cuda.is_available() else "cpu"
height, width = 256, 256  # target size passed to transforms.Resize below

# Dataset root. Set the COCO_TRAIN_DIR environment variable to run outside the
# original VM; the previously hard-coded path remains the default.
path_vm = os.environ.get("COCO_TRAIN_DIR", "/home/default/coco/train/")
if not os.path.isdir(path_vm):
    raise FileNotFoundError(
        f"COCO training directory not found: {path_vm!r} (set COCO_TRAIN_DIR to override)"
    )

transform = transforms.Compose([
    transforms.Resize((height, width)),
    transforms.ToTensor(),
])
dataset = CocoDataset(root=path_vm, transform=transform)
# Fail here rather than silently training/evaluating on zero samples downstream.
assert len(dataset) > 0, f"No samples found under {path_vm!r}"

In [None]:
# Reproducible 80/20 train/test split, then one DataLoader per subset.
# Seeding immediately before random_split pins the permutation it draws.
torch.manual_seed(42)
TEST_FRACTION = 0.2
BATCH_SIZE = 64

test_size = int(TEST_FRACTION * len(dataset))
train_size = len(dataset) - test_size
train, test = random_split(dataset, [train_size, test_size])

# Only the training loader reshuffles each epoch; test order stays fixed.
train_loader, test_loader = (
    DataLoader(subset, batch_size=BATCH_SIZE, shuffle=is_train)
    for subset, is_train in ((train, True), (test, False))
)

In [None]:
# Build the generator/discriminator pair and their optimizers.
# Re-seeding makes the initial weights independent of RNG draws in earlier cells;
# the construction order (generator first) must stay fixed for the same reason.
torch.manual_seed(42)
LEARNING_RATE = 2e-4

generator = EncoderDecoderGenerator().to(device)
generator_opt = optim.Adam(
    generator.parameters(),
    lr=LEARNING_RATE,
    betas=(0.5, 0.999),
)

discriminator = PatchGAN().to(device)
discriminator_opt = optim.SGD(
    discriminator.parameters(),
    lr=LEARNING_RATE,
    momentum=0.9,
    nesterov=True,
)

# NOTE(review): BCELoss expects probabilities in [0, 1], so this assumes PatchGAN
# ends in a sigmoid — confirm in utils.models (otherwise prefer BCEWithLogitsLoss).
criterion = nn.BCELoss()

In [None]:
# Train the GAN. Tune NUM_EPOCHS and L1_LAMBDA here before running.
# NOTE(review): with save_checkpoints/save_losses enabled, train_gan presumably
# writes artifacts under the "gen1" name — confirm in utils.training.
NUM_EPOCHS = 100
L1_LAMBDA = 5

torch.manual_seed(42)
d_losses, g_losses = train_gan(
    NUM_EPOCHS,
    discriminator,
    generator,
    discriminator_opt,
    generator_opt,
    criterion,
    train_loader,
    device,
    l1_lambda=L1_LAMBDA,
    label_smoothing=True,
    add_noise=True,
    save_checkpoints=True,
    save_losses=True,
    file_name="gen1",
)

In [None]:
# Persist the trained generator via the project's save_model helper.
GENERATOR_CHECKPOINT = "generator_trained"
save_model(generator, GENERATOR_CHECKPOINT)

In [None]:
# Training curves, followed by a qualitative check on one fixed example image.
plot_losses(d_losses, g_losses, "Disc", "Gen")

SAMPLE_IDX = 7653  # index into the full `dataset`, not into the train/test subsets
assert 0 <= SAMPLE_IDX < len(dataset), (
    f"SAMPLE_IDX {SAMPLE_IDX} is out of range for a dataset of size {len(dataset)}"
)

# Load the sample once and reuse it for both plots instead of indexing the dataset twice.
sample = dataset[SAMPLE_IDX][1]

# Because the index refers to the full dataset, the image may have been used for
# training; report its split so the prediction quality is interpreted correctly.
split_name = "test" if SAMPLE_IDX in test.indices else "train"
print(f"Sample {SAMPLE_IDX} belongs to the {split_name} split")

plot_l(sample)
# NOTE(review): assumes plot_model_pred handles eval()/no_grad itself — confirm in utils.plots.
plot_model_pred(sample, generator, device)

In [None]:
# Quantitative evaluation on the held-out split. Seeded so any randomness inside
# evaluate_model is repeatable across runs.
torch.manual_seed(42)
avg_mse, std_mse, avg_psnr, std_psnr, avg_ssim, std_ssim, fid = evaluate_model(generator, test_loader, device)

per_image_metrics = (
    ("MSE", avg_mse, std_mse),
    ("PSNR", avg_psnr, std_psnr),
    ("SSIM", avg_ssim, std_ssim),
)
for metric_name, metric_mean, metric_std in per_image_metrics:
    print(f"Average {metric_name}, STD: {metric_mean:.4f}, {metric_std:.4f}")
print(f"FID: {fid:.4f}")