In [None]:
import os
import sys
import numpy as np
import h5py
# Put the parent of the current working directory on the import path.
# Presumably this is the repository root, so that project-local modules
# (networks_2022, data.CleanSoundsDataset) can be imported below — confirm
# the notebook is launched from a direct subdirectory of the repo.
# NOTE: this must run before the project imports that follow.
root = os.path.dirname(os.path.abspath(os.curdir))
sys.path.append(root)

import torch
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import DataLoader
from torch.utils.data import Subset

# Training / evaluation loops for the predictive coders (predify library).
from predify.utils.training import train_pcoders, eval_pcoders

from networks_2022 import BranchedNetwork
from data.CleanSoundsDataset import CleanSoundsDataset, TrainCleanSoundsDataset, PsychophysicsCleanSoundsDataset

# Specify Network to train
**TODO:** Convert this notebook into a script that takes the network to train as a command-line argument.

In [None]:
# Choose the predictive-coding network class to train. `pnet_name` names the
# checkpoint and TensorBoard subdirectories created further down.
from pbranchednetwork_all import PBranchedNetwork_AllSeparateHP

pnet_name = 'pnet_noisy'
PNetClass = PBranchedNetwork_AllSeparateHP

# Parameters

In [3]:
# Run configuration: everything a reader may want to tune lives in this cell.
DEVICE = 'cuda:0' if torch.cuda.is_available() else 'cpu'
print(f'Device: {DEVICE}')
BATCH_SIZE = 50
NUM_WORKERS = 2
# Pinned host memory only helps host->GPU copies; on a CPU-only machine the
# DataLoader just warns about it, so enable it only when CUDA is available.
PIN_MEMORY = torch.cuda.is_available()
NUM_EPOCHS = 70

lr = 1E-5
# Root of the shared storage holding weights, datasets, checkpoints and
# TensorBoard logs. Set the ENGRAM_DIR environment variable to use another
# location. os.path.join(..., '') guarantees a trailing separator, which the
# f-string paths below (and in later cells) rely on.
engram_dir = os.path.join(
    os.environ.get('ENGRAM_DIR', '/mnt/smb/locker/abbott-locker/hcnn/'), '')
checkpoints_dir = f'{engram_dir}checkpoints/'
tensorboard_dir = f'{engram_dir}tensorboard/'
# Training data actually used by this notebook. The alternative
# f'{engram_dir}clean_reconstruction_training_set.hdf5' used to be assigned
# here too, but was immediately overwritten and never used.
train_datafile = f'{engram_dir}hyperparameter_pooled_training_dataset_random_order_noNulls.hdf5'

Device: cuda:0


In [4]:
# Record GPU, driver and CUDA versions for this run (IPython shell escape).
!nvidia-smi

Thu Aug 18 00:46:50 2022       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 515.48.07    Driver Version: 515.48.07    CUDA Version: 11.7     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA GeForce ...  On   | 00000000:23:00.0 Off |                  N/A |
| 27%   24C    P8     5W / 250W |      3MiB / 11264MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+---------------------------------------------------------------------------

# Load network and optimizer

In [5]:
# Build the feedforward backbone and load its pretrained weights.
# map_location='cpu' makes the load work regardless of which device the tensors
# were saved from (e.g. on a CPU-only host). load_state_dict copies them into
# this CPU-resident module; the wrapping pnet is moved to DEVICE in a later cell.
# The returned key-matching report is this cell's displayed output.
net = BranchedNetwork()
net.load_state_dict(torch.load(f'{engram_dir}networks_2022_weights.pt', map_location='cpu'))



<All keys matched successfully>

In [6]:
# Wrap the pretrained backbone in the predictive-coding network class chosen above.
pnet = PNetClass(
    net,
    build_graph=True,
)

In [7]:
# Put the network in evaluation mode. As the cell's last expression, the returned
# module is displayed, showing the full architecture below.
# NOTE(review): presumably train_pcoders switches to training mode itself — confirm.
pnet.eval()

PBranchedNetwork_AllSeparateHP(
  (backbone): BranchedNetwork(
    (speech_branch): Sequential(
      (conv1): ConvLayer(
        (block): Sequential(
          (0): Conv2d(1, 96, kernel_size=(6, 14), stride=(3, 3), padding=(2, 6))
          (1): ReLU()
        )
      )
      (rnorm1): LRNorm(
        (block): LocalResponseNorm(5, alpha=0.005, beta=0.75, k=1.0)
      )
      (pool1): PoolLayer(
        (block): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=True)
      )
      (conv2): ConvLayer(
        (block): Sequential(
          (0): Conv2d(96, 256, kernel_size=(5, 5), stride=(2, 2), padding=(1, 2))
          (1): ReLU()
        )
      )
      (rnorm2): LRNorm(
        (block): LocalResponseNorm(5, alpha=0.005, beta=0.75, k=1.0)
      )
      (pool2): PoolLayer(
        (block): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=True)
      )
      (conv3): ConvLayer(
        (block): Sequential(
          (0): Conv2d(256, 512, kernel_size=

In [8]:
def _pcoder_param_groups(model, learning_rate):
    """Return one Adam parameter group per predictive coder of ``model``.

    Only the ``pmodule`` parameters of ``pcoder1 .. pcoderN`` are collected;
    anything else in the model (e.g. the backbone) is left out of the optimizer.
    """
    return [
        {'params': getattr(model, f"pcoder{i}").pmodule.parameters(), 'lr': learning_rate}
        for i in range(1, model.number_of_pcoders + 1)
    ]


pnet.to(DEVICE)
optimizer = torch.optim.Adam(_pcoder_param_groups(pnet, lr), weight_decay=5e-4)

# Set up train/test datasets

In [9]:
# Both splits come from the same HDF5 file. The shared 0.9 is presumably the
# fraction of samples assigned to training — confirm in CleanSoundsDataset.
train_split = .9
train_dataset = CleanSoundsDataset(train_datafile, train_split)
test_dataset = CleanSoundsDataset(train_datafile, train_split, train=False)

In [10]:
# Training and evaluation loaders use identical settings. shuffle=False keeps the
# stored sample order (the training file's name suggests it is already stored in
# random order — NOTE(review): confirm before relying on it).
_loader_kwargs = dict(
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
    pin_memory=PIN_MEMORY,
)
train_loader = DataLoader(train_dataset, **_loader_kwargs)
eval_loader = DataLoader(test_dataset, **_loader_kwargs)

# Set up checkpoints and TensorBoard logging

In [11]:
# Checkpoints are written to <checkpoints_dir>/<pnet_name>/. checkpoint_path is a
# str.format template filled in with `epoch` and `type` by the training loop.
# exist_ok=True avoids the check-then-create race and makes the cell re-runnable.
os.makedirs(os.path.join(checkpoints_dir, pnet_name), exist_ok=True)
checkpoint_path = os.path.join(checkpoints_dir, pnet_name, pnet_name + '-{epoch}-{type}.pth')

# TensorBoard event files go to <tensorboard_dir>/<pnet_name>/. SummaryWriter is
# already imported in the first cell.
tensorboard_path = os.path.join(tensorboard_dir, pnet_name)
os.makedirs(tensorboard_path, exist_ok=True)
sumwriter = SummaryWriter(tensorboard_path)

# Train

In [None]:
loss_function = torch.nn.MSELoss()
# Number of epochs between periodic checkpoints.
CHECKPOINT_EVERY = 5

for epoch in range(1, NUM_EPOCHS + 1):
    # predify's train/eval loops; sumwriter is passed in, presumably for
    # TensorBoard logging of the losses.
    train_pcoders(pnet, optimizer, loss_function, epoch, train_loader, DEVICE, sumwriter)
    eval_pcoders(pnet, loss_function, epoch, eval_loader, DEVICE, sumwriter)

    # Save periodically, and always after the final epoch, so the last weights are
    # kept even when NUM_EPOCHS is not a multiple of CHECKPOINT_EVERY.
    if epoch % CHECKPOINT_EVERY == 0 or epoch == NUM_EPOCHS:
        torch.save(pnet.state_dict(), checkpoint_path.format(epoch=epoch, type='regular'))

    # Write buffered TensorBoard events to disk once per epoch, so progress is
    # visible (and survives a kernel crash) without waiting for the periodic flush.
    sumwriter.flush()

  "The default behavior for interpolate/upsample with float scale_factor changed "


Training Epoch: 1 [50/67482]	Loss: 53174.1797
Training Epoch: 1 [100/67482]	Loss: 49671.3828
Training Epoch: 1 [150/67482]	Loss: 47478.8164
Training Epoch: 1 [200/67482]	Loss: 44750.2344
Training Epoch: 1 [250/67482]	Loss: 42351.6719
Training Epoch: 1 [300/67482]	Loss: 39993.8867
Training Epoch: 1 [350/67482]	Loss: 37114.4609
Training Epoch: 1 [400/67482]	Loss: 35507.5508
Training Epoch: 1 [450/67482]	Loss: 33815.5430
Training Epoch: 1 [500/67482]	Loss: 31897.8457
Training Epoch: 1 [550/67482]	Loss: 29236.1973
Training Epoch: 1 [600/67482]	Loss: 28484.2305
Training Epoch: 1 [650/67482]	Loss: 26955.4805
Training Epoch: 1 [700/67482]	Loss: 25102.0117
Training Epoch: 1 [750/67482]	Loss: 23256.7266
Training Epoch: 1 [800/67482]	Loss: 22697.1855
Training Epoch: 1 [850/67482]	Loss: 20755.6328
Training Epoch: 1 [900/67482]	Loss: 19294.5742
Training Epoch: 1 [950/67482]	Loss: 19178.0781
Training Epoch: 1 [1000/67482]	Loss: 18101.5742
Training Epoch: 1 [1050/67482]	Loss: 16837.5820
Training Epo

Training Epoch: 1 [8700/67482]	Loss: 8381.4375
Training Epoch: 1 [8750/67482]	Loss: 8420.2168
Training Epoch: 1 [8800/67482]	Loss: 8542.5039
Training Epoch: 1 [8850/67482]	Loss: 8397.6846
Training Epoch: 1 [8900/67482]	Loss: 8493.6592
Training Epoch: 1 [8950/67482]	Loss: 8210.4521
Training Epoch: 1 [9000/67482]	Loss: 8498.1367
Training Epoch: 1 [9050/67482]	Loss: 8350.3857
Training Epoch: 1 [9100/67482]	Loss: 8123.1777
Training Epoch: 1 [9150/67482]	Loss: 8434.9854
Training Epoch: 1 [9200/67482]	Loss: 8449.2695
Training Epoch: 1 [9250/67482]	Loss: 8298.5674
Training Epoch: 1 [9300/67482]	Loss: 8287.6650
Training Epoch: 1 [9350/67482]	Loss: 8187.1670
Training Epoch: 1 [9400/67482]	Loss: 8324.9424
Training Epoch: 1 [9450/67482]	Loss: 7989.1079
Training Epoch: 1 [9500/67482]	Loss: 8111.6782
Training Epoch: 1 [9550/67482]	Loss: 8179.5229
Training Epoch: 1 [9600/67482]	Loss: 8136.7632
Training Epoch: 1 [9650/67482]	Loss: 8082.7246
Training Epoch: 1 [9700/67482]	Loss: 7855.3062
Training Epoc