In [1]:
import os
import sys
import numpy as np
import h5py
import torch
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import DataLoader
from torch.utils.data import Subset
sys.path.append("/mnt/smb/locker/abbott-locker/es3773/from_home/hallucnn/src/data")
sys.path.append('/mnt/smb/locker/abbott-locker/es3773/from_home/hallucnn_orig/src/models/')
from networks_2022 import BranchedNetwork
from MergedNoisyDataset import MergedNoisyDataset


In [2]:
pnet_name = 'pnet_deep_feature_loss_merged_noisy_dataset_v3_with_eval_loss'
_train_datafile = 'clean_reconstruction_training_set'
SoundsDataset = MergedNoisyDataset
dset_kwargs = {}

In [3]:
sys.path.append("/home/es3773/hallucnn/src/models")
from pbranchednetwork_all import PBranchedNetwork_AllSeparateHP
PNetClass = PBranchedNetwork_AllSeparateHP

In [4]:

DEVICE = 'cuda:0' if torch.cuda.is_available() else 'cpu'
print(f'Device: {DEVICE}')
BATCH_SIZE = 30
NUM_WORKERS = 2
PIN_MEMORY = True
NUM_EPOCHS = 300

lr = 1E-5
engram_dir = '/mnt/smb/locker/abbott-locker/hcnn/'
checkpoints_dir = f'{engram_dir}checkpoints/'
tensorboard_dir = f'{engram_dir}tensorboard/'
train_datafile = f'{engram_dir}{_train_datafile}.hdf5'

Device: cuda:0


In [5]:
net = BranchedNetwork()
net.load_state_dict(torch.load(f'{engram_dir}networks_2022_weights.pt'))



<All keys matched successfully>

In [6]:

pnet = PNetClass(net, build_graph=True)

pnet.eval()

pnet.to(DEVICE)
optimizer = torch.optim.Adam(
    [{'params':getattr(pnet,f"pcoder{x+1}").pmodule.parameters(), 'lr':lr} for x in range(pnet.number_of_pcoders)],
    weight_decay=5e-4)

In [7]:
train_dataset = MergedNoisyDataset(subset=0.9, train=True)
        
test_dataset = MergedNoisyDataset(subset=0.9, train=False)
train_dataset.n_data
test_dataset.n_data  

7497

In [8]:
train_loader = DataLoader(
    train_dataset, batch_size=BATCH_SIZE, shuffle=False,
    num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY
    )
eval_loader = DataLoader(
    test_dataset, batch_size=BATCH_SIZE, shuffle=False,
    num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY
    )

In [11]:
from typing import Callable
import torch

def train_pcoders(net: torch.nn.Module, optimizer: torch.optim.Optimizer, loss_function: Callable, epoch: int, train_loader: torch.utils.data.DataLoader, device: str, writer: torch.utils.tensorboard.SummaryWriter=None):
    r"""
    Trains the feedback modules of PCoders using a distance between the prediction of a PCoder and the
    representation of the PCoder below.
    Args:
        net (torch.nn.Module): Predified network including all the PCoders
        optimizer (torch.optim.Optimizer): PyTorch-compatible optimizer object
        loss_function (Callable): A callable function that receives two tensors and returns the distance between them
        epoch (int): Training epoch number
        train_loader (torch.utils.data.DataLoader): DataLoader for training samples
        writer (torch.utils.tensorboard.SummaryWrite, optional): Tensorboard summary writer to track training history. Default: None
        device (str): Training device (e.g. 'cpu', 'cuda:0')
    """
    layers = ['conv1', 'conv2', 'conv3', 'conv4_W', 'conv5_W', 'fc6_W']
    net.train()
    net.backbone.eval()
    activations_dir = '/mnt/smb/locker/abbott-locker/hcnn/3_activations/FF/'
    hdf5_outpath = f'{activations_dir}clean_hyperparameter_matched_training_set_activations.hdf5'
    with h5py.File(hdf5_outpath, 'r') as f_in: 
        nb_trained_samples = 0
        for batch_index, (images, _) in enumerate(train_loader):
            batch_size = np.shape(images)[0]
            net.reset()
            images = images.to(device)
            
            optimizer.zero_grad()
            outputs = net(images)
            for i in range(net.number_of_pcoders):
                if i == 0:
                    a = loss_function(net.pcoder1.prd, images)
                    loss = a
                else:
                    pcoder_pre =f_in[layers[i-1]+'_activations'][batch_index*batch_size: (batch_index+1)*batch_size] #getattr(net, f"pcoder{i}")
                    pcoder_pre = torch.tensor(pcoder_pre).to(device)
                    
                    pcoder_curr =  getattr(net, f"pcoder{i+1}")
                    a = loss_function(pcoder_curr.prd, pcoder_pre)
                    loss += a
                if writer is not None:
                    writer.add_scalar(f"MSE Train/PCoder{i+1}", a.item(), (epoch-1) * len(train_loader) + batch_index)

            nb_trained_samples += images.shape[0]

            loss.backward()
            optimizer.step()

            print('Training Epoch: {epoch} [{trained_samples}/{total_samples}]\tLoss: {:0.4f}'.format(
                loss.item(),
                epoch=epoch,
                trained_samples=nb_trained_samples,
                total_samples=len(train_loader.dataset)
            ))
            if writer is not None:
                writer.add_scalar(f"MSE Train/Sum", loss.item(), (epoch-1) * len(train_loader) + batch_index)


def eval_pcoders(offset: int, net: torch.nn.Module, loss_function: Callable, epoch: int, eval_loader: torch.utils.data.DataLoader, device: str, writer: torch.utils.tensorboard.SummaryWriter=None):
    """
    Evaluates the feedback modules of PCoders using a distance between the prediction of a PCoder and the
    representation of the PCoder below.
    Args:
        net (torch.nn.Module): Predified network including all the PCoders
        loss_function (Callable): A callable function that receives two tensors and returns the distance between them
        epoch (int): Evaluation epoch number
        test_loader (torch.utils.data.DataLoader): DataLoader for evaluation samples
        writer (torch.utils.tensorboard.SummaryWrite, optional): Tensorboard summary writer to track evaluation history. Default: None
        device (str): Training device (e.g. 'cpu', 'cuda:0')
    """
    layers = ['conv1', 'conv2', 'conv3', 'conv4_W', 'conv5_W', 'fc6_W']
    activations_dir = '/mnt/smb/locker/abbott-locker/hcnn/3_activations/FF/'
    hdf5_outpath = f'{activations_dir}clean_hyperparameter_matched_training_set_activations.hdf5'
    with h5py.File(hdf5_outpath, 'r') as f_in: 
        net.eval()
        
        
        ## NEED TO ADD OFFSET TO ACCOUNT FOR 

        final_loss = [0 for i in range(net.number_of_pcoders)]
        for batch_index, (images, _) in enumerate(eval_loader):
            batch_size = np.shape(images)[0]
            net.reset()
            images = images.to(device)
            with torch.no_grad():
                outputs = net(images)
            for i in range(net.number_of_pcoders):
                print(i)
                if i == 0:
                    final_loss[i] += loss_function(net.pcoder1.prd, images).item()
                else:
                    
                    start = ((batch_index*batch_size) + offset) 
                    end = (((batch_index+1)*batch_size) + offset) 
                    
                    pcoder_pre =f_in[layers[i-1]+'_activations'][start:end] #getattr(net, f"pcoder{i}")
                 
                    pcoder_curr = getattr(net, f"pcoder{i+1}")
                    final_loss[i] += loss_function(pcoder_curr.prd, pcoder_pre).item()

        loss_sum = 0
        for i in range(net.number_of_pcoders):
            final_loss[i] /= len(eval_loader)
            loss_sum += final_loss[i]
            print(final_loss)
            if writer is not None:
                writer.add_scalar(f"MSE Eval/PCoder{i+1}", final_loss[i], epoch-1)


        print('Training Epoch: {epoch} [{evaluated_samples}/{total_samples}]\tLoss: {:0.4f}'.format(
            loss_sum,
            epoch=epoch,
            evaluated_samples=len(eval_loader.dataset),
            total_samples=len(eval_loader.dataset)
        ))
        if writer is not None:
            writer.add_scalar(f"MSE Eval/Sum", loss_sum, epoch-1)

In [12]:
checkpoint_path = os.path.join(checkpoints_dir, f"{pnet_name}")
if not os.path.exists(checkpoint_path):
    os.makedirs(checkpoint_path)
checkpoint_path = os.path.join(checkpoint_path, pnet_name + '-{epoch}-{type}.pth')

# summarywriter
from torch.utils.tensorboard import SummaryWriter
tensorboard_path = os.path.join(tensorboard_dir, f"{pnet_name}")
if not os.path.exists(tensorboard_path):
    os.makedirs(tensorboard_path)
sumwriter = SummaryWriter(tensorboard_path, filename_suffix=f'')

In [None]:
loss_function = torch.nn.L1Loss(reduction='mean')

for epoch in range(1, NUM_EPOCHS+1):
    train_pcoders(pnet, optimizer, loss_function, epoch, train_loader, DEVICE, sumwriter)
    eval_pcoders(pnet, loss_function, epoch, eval_loader, DEVICE, sumwriter)

    # save checkpoints every 5 "epochs
    if epoch % 5 == 0:
        torch.save(train_dataset.n_data, pnet.state_dict(), checkpoint_path.format(epoch=epoch, type='regular'))

  "The default behavior for interpolate/upsample with float scale_factor changed "


Training Epoch: 1 [30/67482]	Loss: 556.5466
Training Epoch: 1 [60/67482]	Loss: 554.0670
Training Epoch: 1 [90/67482]	Loss: 550.2391
Training Epoch: 1 [120/67482]	Loss: 544.7906
Training Epoch: 1 [150/67482]	Loss: 535.3878
Training Epoch: 1 [180/67482]	Loss: 523.2587
Training Epoch: 1 [210/67482]	Loss: 518.9567
Training Epoch: 1 [240/67482]	Loss: 500.9839
Training Epoch: 1 [270/67482]	Loss: 506.0135
Training Epoch: 1 [300/67482]	Loss: 503.4045
Training Epoch: 1 [330/67482]	Loss: 482.3760
Training Epoch: 1 [360/67482]	Loss: 489.2063
Training Epoch: 1 [390/67482]	Loss: 478.5609
Training Epoch: 1 [420/67482]	Loss: 498.2364
Training Epoch: 1 [450/67482]	Loss: 472.9398
Training Epoch: 1 [480/67482]	Loss: 467.1927
Training Epoch: 1 [510/67482]	Loss: 457.7351
Training Epoch: 1 [540/67482]	Loss: 444.4552
Training Epoch: 1 [570/67482]	Loss: 449.5471
Training Epoch: 1 [600/67482]	Loss: 454.9693
Training Epoch: 1 [630/67482]	Loss: 444.1888
Training Epoch: 1 [660/67482]	Loss: 451.5621
Training Epoc

Training Epoch: 1 [5400/67482]	Loss: 135.0438
Training Epoch: 1 [5430/67482]	Loss: 134.0845
Training Epoch: 1 [5460/67482]	Loss: 135.0941
Training Epoch: 1 [5490/67482]	Loss: 134.4748
Training Epoch: 1 [5520/67482]	Loss: 129.8926
Training Epoch: 1 [5550/67482]	Loss: 136.0229
Training Epoch: 1 [5580/67482]	Loss: 136.4851
Training Epoch: 1 [5610/67482]	Loss: 129.1438
Training Epoch: 1 [5640/67482]	Loss: 136.4847
Training Epoch: 1 [5670/67482]	Loss: 140.2888
Training Epoch: 1 [5700/67482]	Loss: 127.4764
Training Epoch: 1 [5730/67482]	Loss: 140.9365
Training Epoch: 1 [5760/67482]	Loss: 133.6520
Training Epoch: 1 [5790/67482]	Loss: 130.1189
Training Epoch: 1 [5820/67482]	Loss: 129.0533
Training Epoch: 1 [5850/67482]	Loss: 136.3812
Training Epoch: 1 [5880/67482]	Loss: 138.1537
Training Epoch: 1 [5910/67482]	Loss: 139.5058
Training Epoch: 1 [5940/67482]	Loss: 132.3652
Training Epoch: 1 [5970/67482]	Loss: 128.4603
Training Epoch: 1 [6000/67482]	Loss: 136.4644
Training Epoch: 1 [6030/67482]	Los

Training Epoch: 1 [10740/67482]	Loss: 114.9808
Training Epoch: 1 [10770/67482]	Loss: 110.3417
Training Epoch: 1 [10800/67482]	Loss: 111.2141
Training Epoch: 1 [10830/67482]	Loss: 110.8032
Training Epoch: 1 [10860/67482]	Loss: 106.8683
Training Epoch: 1 [10890/67482]	Loss: 109.6098
Training Epoch: 1 [10920/67482]	Loss: 112.2244
Training Epoch: 1 [10950/67482]	Loss: 109.5504
Training Epoch: 1 [10980/67482]	Loss: 106.3094
Training Epoch: 1 [11010/67482]	Loss: 104.9230
Training Epoch: 1 [11040/67482]	Loss: 114.9384
Training Epoch: 1 [11070/67482]	Loss: 110.2971
Training Epoch: 1 [11100/67482]	Loss: 108.5905
Training Epoch: 1 [11130/67482]	Loss: 111.5316
Training Epoch: 1 [11160/67482]	Loss: 113.6264
Training Epoch: 1 [11190/67482]	Loss: 109.6748
Training Epoch: 1 [11220/67482]	Loss: 107.7052
Training Epoch: 1 [11250/67482]	Loss: 110.5952
Training Epoch: 1 [11280/67482]	Loss: 112.3052
Training Epoch: 1 [11310/67482]	Loss: 108.6384
Training Epoch: 1 [11340/67482]	Loss: 108.7230
Training Epoc

Training Epoch: 1 [16050/67482]	Loss: 91.8521
Training Epoch: 1 [16080/67482]	Loss: 87.4816
Training Epoch: 1 [16110/67482]	Loss: 90.1961
Training Epoch: 1 [16140/67482]	Loss: 90.3085
Training Epoch: 1 [16170/67482]	Loss: 88.1848
Training Epoch: 1 [16200/67482]	Loss: 91.6718
Training Epoch: 1 [16230/67482]	Loss: 89.2986
Training Epoch: 1 [16260/67482]	Loss: 89.6663
Training Epoch: 1 [16290/67482]	Loss: 91.7166
Training Epoch: 1 [16320/67482]	Loss: 86.0477
Training Epoch: 1 [16350/67482]	Loss: 88.3860
Training Epoch: 1 [16380/67482]	Loss: 89.1201
Training Epoch: 1 [16410/67482]	Loss: 87.1873
Training Epoch: 1 [16440/67482]	Loss: 90.9019
Training Epoch: 1 [16470/67482]	Loss: 87.4648
Training Epoch: 1 [16500/67482]	Loss: 87.9325
Training Epoch: 1 [16530/67482]	Loss: 87.1435
Training Epoch: 1 [16560/67482]	Loss: 87.0647
Training Epoch: 1 [16590/67482]	Loss: 85.5185
Training Epoch: 1 [16620/67482]	Loss: 92.0339
Training Epoch: 1 [16650/67482]	Loss: 94.7238
Training Epoch: 1 [16680/67482]	Lo

Training Epoch: 1 [21420/67482]	Loss: 71.1836
Training Epoch: 1 [21450/67482]	Loss: 70.0055
Training Epoch: 1 [21480/67482]	Loss: 71.5196
Training Epoch: 1 [21510/67482]	Loss: 69.9345
Training Epoch: 1 [21540/67482]	Loss: 73.5457
Training Epoch: 1 [21570/67482]	Loss: 70.6588
Training Epoch: 1 [21600/67482]	Loss: 68.6595
Training Epoch: 1 [21630/67482]	Loss: 70.7238
Training Epoch: 1 [21660/67482]	Loss: 69.8450
Training Epoch: 1 [21690/67482]	Loss: 69.1971
Training Epoch: 1 [21720/67482]	Loss: 70.5422
Training Epoch: 1 [21750/67482]	Loss: 69.8813
Training Epoch: 1 [21780/67482]	Loss: 70.3521
Training Epoch: 1 [21810/67482]	Loss: 70.0435
Training Epoch: 1 [21840/67482]	Loss: 71.2357
Training Epoch: 1 [21870/67482]	Loss: 69.9645
Training Epoch: 1 [21900/67482]	Loss: 66.0049
Training Epoch: 1 [21930/67482]	Loss: 69.1803
Training Epoch: 1 [21960/67482]	Loss: 68.9389
Training Epoch: 1 [21990/67482]	Loss: 72.0743
Training Epoch: 1 [22020/67482]	Loss: 69.4739
Training Epoch: 1 [22050/67482]	Lo

Training Epoch: 1 [26790/67482]	Loss: 58.2194
Training Epoch: 1 [26820/67482]	Loss: 56.8514
Training Epoch: 1 [26850/67482]	Loss: 56.5717
Training Epoch: 1 [26880/67482]	Loss: 55.4524
Training Epoch: 1 [26910/67482]	Loss: 55.6440
Training Epoch: 1 [26940/67482]	Loss: 55.5832
Training Epoch: 1 [26970/67482]	Loss: 55.9259
Training Epoch: 1 [27000/67482]	Loss: 56.5553
Training Epoch: 1 [27030/67482]	Loss: 58.0639
Training Epoch: 1 [27060/67482]	Loss: 56.0693
Training Epoch: 1 [27090/67482]	Loss: 56.9625
Training Epoch: 1 [27120/67482]	Loss: 57.2258
Training Epoch: 1 [27150/67482]	Loss: 56.1491
Training Epoch: 1 [27180/67482]	Loss: 58.1421
Training Epoch: 1 [27210/67482]	Loss: 56.9554
Training Epoch: 1 [27240/67482]	Loss: 57.5479
Training Epoch: 1 [27270/67482]	Loss: 57.1936
Training Epoch: 1 [27300/67482]	Loss: 56.5594
Training Epoch: 1 [27330/67482]	Loss: 54.3352
Training Epoch: 1 [27360/67482]	Loss: 55.7553
Training Epoch: 1 [27390/67482]	Loss: 57.0289
Training Epoch: 1 [27420/67482]	Lo

Training Epoch: 1 [32160/67482]	Loss: 48.1547
Training Epoch: 1 [32190/67482]	Loss: 48.9146
Training Epoch: 1 [32220/67482]	Loss: 49.2791
Training Epoch: 1 [32250/67482]	Loss: 48.9784
Training Epoch: 1 [32280/67482]	Loss: 49.1673
Training Epoch: 1 [32310/67482]	Loss: 47.7083
Training Epoch: 1 [32340/67482]	Loss: 50.1835
Training Epoch: 1 [32370/67482]	Loss: 50.4576
Training Epoch: 1 [32400/67482]	Loss: 47.7796
Training Epoch: 1 [32430/67482]	Loss: 49.1076
Training Epoch: 1 [32460/67482]	Loss: 48.2294
Training Epoch: 1 [32490/67482]	Loss: 47.1699
Training Epoch: 1 [32520/67482]	Loss: 48.2209
Training Epoch: 1 [32550/67482]	Loss: 47.6171
Training Epoch: 1 [32580/67482]	Loss: 47.8750
Training Epoch: 1 [32610/67482]	Loss: 47.9419
Training Epoch: 1 [32640/67482]	Loss: 49.0693
Training Epoch: 1 [32670/67482]	Loss: 48.3703
Training Epoch: 1 [32700/67482]	Loss: 48.2216
Training Epoch: 1 [32730/67482]	Loss: 47.8462
Training Epoch: 1 [32760/67482]	Loss: 48.9074
Training Epoch: 1 [32790/67482]	Lo

Training Epoch: 1 [37530/67482]	Loss: 44.0758
Training Epoch: 1 [37560/67482]	Loss: 43.5917
Training Epoch: 1 [37590/67482]	Loss: 43.6650
Training Epoch: 1 [37620/67482]	Loss: 43.5276
Training Epoch: 1 [37650/67482]	Loss: 43.5766
Training Epoch: 1 [37680/67482]	Loss: 43.8860
Training Epoch: 1 [37710/67482]	Loss: 43.5808
Training Epoch: 1 [37740/67482]	Loss: 43.3974
Training Epoch: 1 [37770/67482]	Loss: 44.1185
Training Epoch: 1 [37800/67482]	Loss: 43.3031
Training Epoch: 1 [37830/67482]	Loss: 43.2447
Training Epoch: 1 [37860/67482]	Loss: 44.5105
Training Epoch: 1 [37890/67482]	Loss: 44.5308
Training Epoch: 1 [37920/67482]	Loss: 43.9460
Training Epoch: 1 [37950/67482]	Loss: 43.7863
Training Epoch: 1 [37980/67482]	Loss: 43.4628
Training Epoch: 1 [38010/67482]	Loss: 43.2648
Training Epoch: 1 [38040/67482]	Loss: 43.9378
Training Epoch: 1 [38070/67482]	Loss: 43.3513
Training Epoch: 1 [38100/67482]	Loss: 42.9002
Training Epoch: 1 [38130/67482]	Loss: 43.2166
Training Epoch: 1 [38160/67482]	Lo