In [1]:
from pathlib import Path
import os

import matplotlib
import matplotlib.pyplot as plt

import numpy as np

import torch
import torch.nn.functional as F
import torch.optim

from predictive_coding.dataset import collate_fn, EnvironmentDataset
from predictive_coding.trainer import Trainer
from predictive_coding.models.models import Autoencoder, PredictiveCoder

In [2]:
# Build the train / validation loaders. Both share the collate function,
# worker count and pinned-memory setting; only the dataset root, batch size and
# shuffling differ, so they are constructed through one helper instead of two
# copy-pasted DataLoader calls.
def _make_loader(root, batch_size, shuffle, num_workers=4):
    """Return ``(dataset, loader)`` for the ``EnvironmentDataset`` stored at ``root``.

    Args:
        root: directory (str or Path) handed to ``EnvironmentDataset``.
        batch_size: samples per batch.
        shuffle: reshuffle every epoch (training) vs. fixed order (validation).
        num_workers: number of DataLoader worker processes.

    Returns:
        Tuple ``(dataset, loader)``.
    """
    dataset = EnvironmentDataset(Path(root))
    loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        collate_fn=collate_fn,
        num_workers=num_workers,
        # Copies batches into page-locked host memory before returning them,
        # which allows faster host->GPU transfers.
        pin_memory=True,
    )
    return dataset, loader


train_dataset, train_dataloader = _make_loader("../datasets/train-dataset", batch_size=64, shuffle=True)
val_dataset, val_dataloader = _make_loader("../datasets/val-dataset", batch_size=100, shuffle=False)


In [3]:
# Train the predictive-coding model.
# The epoch count and peak learning rate are defined once so the OneCycleLR
# schedule and the training loop cannot drift apart: OneCycleLR raises a
# ValueError if it is stepped more than epochs * steps_per_epoch times.
experiment_name = 'predictive-coding'
num_epochs = 200
max_lr = 1e-1
device = 'cuda:0'

# Seed torch's global RNG so weight init and the DataLoader shuffle order
# (RandomSampler draws its seed from the global generator) are repeatable.
torch.manual_seed(0)

model = PredictiveCoder(in_channels=3, out_channels=3, layers=[2, 2, 2, 2], seq_len=10, num_skip=3)
model = model.to(device)
# OneCycleLR overwrites the optimizer's lr (starting at max_lr / div_factor) and,
# with the default cycle_momentum=True, also cycles SGD momentum, so lr and
# momentum here are effectively initial placeholders.
optimizer = torch.optim.SGD(model.parameters(), lr=max_lr, momentum=0.9, weight_decay=5e-6)
# NOTE(review): steps_per_epoch=len(train_dataloader) assumes Trainer steps the
# scheduler once per training batch - confirm in predictive_coding.trainer.
scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer, max_lr=max_lr, epochs=num_epochs, steps_per_epoch=len(train_dataloader)
)

ckpt_path = os.path.abspath(os.path.join('./experiments', experiment_name))
os.makedirs(ckpt_path, exist_ok=True)  # no-op when the directory already exists
trainer = Trainer(model, optimizer, scheduler, train_dataloader, val_dataloader,
                  checkpoint_path=ckpt_path)
trainer.fit(num_epochs=num_epochs)


  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/205 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [None]:
# Train the autoencoder baseline with the same optimisation setup as the
# predictive-coding run. The epoch count and peak learning rate are defined once
# so the OneCycleLR schedule and the training loop cannot drift apart:
# OneCycleLR raises a ValueError if stepped more than epochs * steps_per_epoch times.
experiment_name = 'autoencoder'
num_epochs = 200
max_lr = 1e-1
device = 'cuda:0'

# Seed torch's global RNG so weight init and the DataLoader shuffle order
# (RandomSampler draws its seed from the global generator) are repeatable.
torch.manual_seed(0)

model = Autoencoder(in_channels=3, out_channels=3, layers=[2, 2, 2, 2])
model = model.to(device)
# OneCycleLR overwrites the optimizer's lr (starting at max_lr / div_factor) and,
# with the default cycle_momentum=True, also cycles SGD momentum, so lr and
# momentum here are effectively initial placeholders.
optimizer = torch.optim.SGD(model.parameters(), lr=max_lr, momentum=0.9, weight_decay=5e-6)
# NOTE(review): steps_per_epoch=len(train_dataloader) assumes Trainer steps the
# scheduler once per training batch - confirm in predictive_coding.trainer.
scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer, max_lr=max_lr, epochs=num_epochs, steps_per_epoch=len(train_dataloader)
)

ckpt_path = os.path.abspath(os.path.join('./experiments', experiment_name))
os.makedirs(ckpt_path, exist_ok=True)  # no-op when the directory already exists
trainer = Trainer(model, optimizer, scheduler, train_dataloader, val_dataloader,
                  checkpoint_path=ckpt_path)
trainer.fit(num_epochs=num_epochs)
