In [1]:
import os
import sys
# Make the parent of the notebook's working directory importable; presumably this is
# where local modules such as networks_2022 live (confirm). The membership guard keeps
# re-runs of this cell from appending duplicate sys.path entries.
root = os.path.dirname(os.path.abspath(os.curdir))
if root not in sys.path:
    sys.path.append(root)

import torch
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import DataLoader
from torch.utils.data import Subset

from predify.utils.training import train_pcoders, eval_pcoders

from networks_2022 import BranchedNetwork
from pbranchednetwork_a1 import PBranchedNetwork_A1SeparateHP
from data.CleanSoundsDataset import CleanSoundsDataset

# Parameters

In [2]:
DEVICE = 'cpu'
BATCH_SIZE = 256
NUM_WORKERS = 0
PIN_MEMORY = False
NUM_EPOCHS = 50
RANDOM_SEED = 42

# Seed torch's global RNG before the network is constructed and the DataLoaders are
# created, so torch-based weight initialisation and shuffling are reproducible.
torch.manual_seed(RANDOM_SEED)

lr = 1E-4
checkpoints_dir = '../../models/checkpoints/'
tensorboard_dir = '../../models/tensorboard/'
datafile = '../../data/seed_542_word_clean_random_order.hdf5'

# Load network and optimizer

In [3]:
# Feedforward backbone; it is passed to PBranchedNetwork_A1SeparateHP in the next cell.
net = BranchedNetwork()



In [4]:
# Wrap the backbone with the predictive-coding modules (pcoder1, pcoder2, ...).
pnet = PBranchedNetwork_A1SeparateHP(
    net,
    build_graph=True,
)

In [5]:
# NOTE(review): train_pcoders presumably sets the train/eval modes it needs itself;
# confirm that eval() here is only the default state between epochs.
pnet.eval()
pnet.to(DEVICE)

# One Adam parameter group per PCoder. Only each PCoder's pmodule parameters are
# handed to the optimizer, so nothing else in pnet is updated by it.
pcoder_param_groups = []
for pcoder_idx in range(1, pnet.number_of_pcoders + 1):
    pcoder = getattr(pnet, f"pcoder{pcoder_idx}")
    pcoder_param_groups.append({'params': pcoder.pmodule.parameters(), 'lr': lr})

optimizer = torch.optim.Adam(pcoder_param_groups, weight_decay=5e-4)

# Set up dataset

In [6]:
# NOTE(review): training and evaluation use the very same dataset object, so the eval
# loss is not measured on held-out data. `subset=1000` presumably limits the number of
# samples loaded — confirm against CleanSoundsDataset.
sounds_dataset = CleanSoundsDataset(datafile, subset=1000)
train_dataset = sounds_dataset
eval_dataset = sounds_dataset

  self.data = torch.tensor(f['data']).reshape((-1, 164, 400))


In [7]:
# Optional sub-sampling for quick experiments. Leave a size at None to use the whole
# dataset; set it to an int to keep a seeded random subset of that many samples.
# Re-run the dataset cell first if you re-run this one, since it wraps the datasets.
TRAIN_SUBSET_SIZE = None
EVAL_SUBSET_SIZE = None

subset_generator = torch.Generator().manual_seed(RANDOM_SEED)
if TRAIN_SUBSET_SIZE is not None:
    train_indices = torch.randperm(len(train_dataset), generator=subset_generator)[:TRAIN_SUBSET_SIZE].tolist()
    train_dataset = Subset(train_dataset, train_indices)
if EVAL_SUBSET_SIZE is not None:
    eval_indices = torch.randperm(len(eval_dataset), generator=subset_generator)[:EVAL_SUBSET_SIZE].tolist()
    eval_dataset = Subset(eval_dataset, eval_indices)

In [8]:
# Both loaders share batching/worker settings; only training data is shuffled.
loader_kwargs = dict(batch_size=BATCH_SIZE, num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY)
train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
eval_loader = DataLoader(eval_dataset, shuffle=False, **loader_kwargs)

# Set up checkpoint and TensorBoard log directories

In [9]:
RUN_NAME = 'pnet-a1'

# Checkpoints go to <checkpoints_dir>/pnet-a1/. `checkpoint_path` is a format template
# whose {epoch} and {type} fields are filled in by the training loop.
run_checkpoint_dir = os.path.join(checkpoints_dir, RUN_NAME)
os.makedirs(run_checkpoint_dir, exist_ok=True)
checkpoint_path = os.path.join(run_checkpoint_dir, RUN_NAME + '-{epoch}-{type}.pth')

# TensorBoard event files for this run (SummaryWriter is imported in the first cell).
tensorboard_path = os.path.join(tensorboard_dir, RUN_NAME)
os.makedirs(tensorboard_path, exist_ok=True)
sumwriter = SummaryWriter(tensorboard_path)

In [10]:
%load_ext tensorboard
%tensorboard --logdir tensorboard

# Train

In [None]:
def to_display_image(tensor):
    """Prepare one sample tensor for ``SummaryWriter.add_image``.

    Detaches the tensor, moves it to CPU as float, drops singleton dimensions and
    min-max scales the values into [0, 1] (add_image expects float images in that
    range). Returns the scaled tensor and the matching ``dataformats`` string:
    'HW' for a 2-D result, otherwise 'CHW' (assumes channel-first — TODO confirm).
    """
    img = tensor.detach().float().cpu().squeeze()
    img = img - img.min()
    peak = img.max()
    if peak > 0:
        img = img / peak
    return img, ('HW' if img.dim() == 2 else 'CHW')


loss_function = torch.nn.MSELoss()
# NOTE(review): earlier runs emitted warnings from F.mse_loss inside the pcoders, which
# usually means prediction and target shapes differ and are being broadcast — verify.
for epoch in range(1, NUM_EPOCHS+1):
    train_pcoders(pnet, optimizer, loss_function, epoch, train_loader, DEVICE, sumwriter)
    eval_pcoders(pnet, loss_function, epoch, eval_loader, DEVICE, sumwriter)

    # Track the reconstruction of a single evaluation sample through epochs. The inputs
    # are presumably sound representations rather than mean/std-normalized images, so
    # they are min-max scaled for display (image MEAN/STD are not defined here).
    sample_input, input_fmt = to_display_image(pnet.input_mem[0])
    sample_recon, recon_fmt = to_display_image(pnet.pcoder1.prd[0])
    sumwriter.add_image('Training Feedback Weights/sample input', sample_input, epoch, dataformats=input_fmt)
    sumwriter.add_image('Training Feedback Weights/sample reconstruction', sample_recon, epoch, dataformats=recon_fmt)
    # Push this epoch's events to disk so TensorBoard shows progress promptly.
    sumwriter.flush()

    # save checkpoints every 5 epochs
    if epoch % 5 == 0:
        torch.save(pnet.state_dict(), checkpoint_path.format(epoch=epoch, type='regular'))

  self.prediction_error  = nn.functional.mse_loss(self.prd, target)
  return F.mse_loss(input, target, reduction=self.reduction)


Training Epoch: 1 [256/1000]	Loss: 0.0258
Training Epoch: 1 [512/1000]	Loss: 0.0149


In [None]:
# (number of training samples, batches per epoch); a bare DataLoader repr shows only its address.
len(train_loader.dataset), len(train_loader)

In [None]:
# Show the checkpoint filename template; the training loop fills {epoch} and {type} via str.format.
checkpoint_path