In [5]:
# wandb must be installed in this kernel's environment (e.g. `%pip install -q wandb`).
import wandb

# Authenticate with Weights & Biases (may prompt for an API key if none is configured).
# The boolean result is left as the cell's last expression so it is displayed.
wandb_logged_in = wandb.login()
wandb_logged_in

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
[34m[1mwandb[0m: Paste an API key from your profile and hit enter, or press ctrl+c to quit:

  ········


[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: C:\Users\woodleighj\.netrc


True

In [2]:
import torch
import torchvision
from torch.utils.data import DataLoader, Subset, random_split
from torchvision import transforms
from LMLTransformer import LMLTransformer
from ModelHelper import ModelHelper
from utils import SuperResolutionDataset
import yaml

In [7]:
# Loading config.
# YAML files are normally UTF-8; pass the encoding explicitly so the file is not decoded
# with the platform's locale codec (e.g. cp1252 on Windows) when it contains non-ASCII text.
with open('config.yaml', 'r', encoding='utf-8') as file:
    config = yaml.safe_load(file)

# safe_load returns None for an empty file; fail here with a clear message instead of a
# confusing "NoneType is not subscriptable" in a later cell.
if not isinstance(config, dict):
    raise ValueError(f"config.yaml must contain a mapping at the top level, got {type(config).__name__}")

In [None]:
# Data modifications/transforms.
# HR targets: landscape images are turned to portrait, center-cropped to the configured
# size, then converted to tensors. LR inputs: resized, blurred, then converted to tensors.
def _to_portrait(img):
    """Rotate a landscape PIL image by -90 degrees (expanding the canvas); return others unchanged."""
    if img.height < img.width:
        return img.rotate(-90, expand=True)
    return img


crop_size = (config["training"]["image_height"], config["training"]["image_width"])

hr_transforms = transforms.Compose([
    transforms.Lambda(_to_portrait),
    transforms.CenterCrop(crop_size),
    transforms.ToTensor(),
])

# NOTE(review): Resize with a single int matches the image's *shorter edge* to that many
# pixels. If "training_scale_factor" is a downscale factor (e.g. 4) rather than a target
# edge length, this yields a tiny LR image — confirm the intended LR size.
lr_transforms = transforms.Compose([
    transforms.Resize(int(config["training"]["training_scale_factor"])),
    transforms.GaussianBlur(kernel_size=(5, 5), sigma=(0.1, 2.0)),
    transforms.ToTensor(),
])

In [None]:
# Loading datasets.
# NOTE(review): presumably SuperResolutionDataset produces (LR, HR) pairs from the images
# under ./data/train using hr_transforms / lr_transforms — confirm in utils.py.
dataset = SuperResolutionDataset(root='./data/train', hr_transforms=hr_transforms, lr_transforms=lr_transforms)

test_size = int(len(dataset) * config["training"]["testing_data_split"])
train_size = len(dataset) - test_size

# Seed the split so the train/test partition is identical across kernel restarts.
# An optional "seed" key under "training" overrides the default of 0.
split_generator = torch.Generator().manual_seed(int(config["training"].get("seed", 0)))
train_dataset, test_dataset = random_split(dataset, [train_size, test_size], generator=split_generator)

train_loader = DataLoader(train_dataset, batch_size=config["training"]["batch_size"], shuffle=True)
# NOTE(review): shuffling the test loader is kept from the original; it only matters if
# ModelHelper evaluates a subset of test batches — confirm whether that is intended.
test_loader = DataLoader(test_dataset, batch_size=config["training"]["batch_size"], shuffle=True)

In [None]:
# Initializing model, optimizer, and helper.
model = LMLTransformer(
    n_blocks=config["model"]["n_blocks"],
    levels=config["model"]["levels"],
    # window_size reads its own config key; configs that do not define it fall back to
    # n_blocks, which is the value this argument originally received.
    window_size=config["model"].get("window_size", config["model"]["n_blocks"]),
    dim=config["model"]["dim"],
    scale_factor=config["model"]["scale_factor"],
)

optimizer = torch.optim.AdamW(model.parameters(), lr=config["training"]["learning_rate"])

helper = ModelHelper(model, optimizer)

In [8]:
# Optionally start a Weights & Biases run; the model's parameter count is added to the
# logged config under config["model"]["size"].
# run_name is always defined (None when logging is disabled) so later cells can test it
# without a NameError.
run_name = None
if config["training"]["log"]:
    config["model"]["size"] = helper.get_parameter_count()
    wandb.init(
        project="SuperResolution",
        config=config
    )
    run_name = wandb.run.name

[34m[1mwandb[0m: Currently logged in as: [33mwoodleighj[0m ([33mjackwoodleigh[0m). Use [1m`wandb login --relogin`[0m to force relogin


In [None]:
# Train, then always close the W&B run — even if training raises or is interrupted
# (e.g. stopping the kernel), so the run is not left open. A failed run is reported
# with a non-zero exit code; the exception is re-raised unchanged.
try:
    helper.train_model(
        train_loader,
        test_loader,
        config["training"]["epoch"],
        config["model"]["batches_per_epoch"],
        config["model"]["perceptual_loss_scale"]
    )
except BaseException:
    wandb.finish(exit_code=1)
    raise
else:
    wandb.finish()

VBox(children=(Label(value='0.013 MB of 0.013 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))