In [1]:
import torch
import numpy as np
import scipy
import torchvision.models as models
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from skimage import io, transform
import matplotlib.pyplot as plt
from collections import OrderedDict
from sklearn import decomposition
from sklearn import svm
from matplotlib import cm
import torch.nn.functional as F
from torch import nn
import math
import pandas as pd
import os
import shapedimutils as s

%matplotlib inline

# Base folder of the shapeDim project; every other path in this notebook is built from it.
root = '/mnt/neurosphere/serenceslab2/maggie/shapeDim/'

# Folder holding the stimulus images (the label CSV read below lives in the same folder).
image_dir = os.path.join(root,'Stimuli','AmpGrid3_adj_full_grey_small')
# Folder handed to the fine-tuning / evaluation helpers as the place for saved models.
save_model_dir = os.path.join(root,'Modeling','saved_models')
# exist_ok=True creates the folder only when it is missing, without a separate
# exists() check that could race with another process creating it first.
os.makedirs(save_model_dir, exist_ok=True)

# Label table stored alongside the images (presumably one row per stimulus -- TODO confirm).
csv_file = os.path.join(image_dir, 'shape_labels_all.csv')
shape_labels = pd.read_csv(csv_file)

In [4]:
# Sanity check of the dataset split helper: build the three splits and
# look at one item from each.

trn, val, tst = s.get_dataset_splits(image_dir)
splits = (trn, val, tst)

# number of items in the train / validation / test splits
print([len(split) for split in splits])

# 'task_labels' entry of the first item in each split
for split in splits:
    print(split[0]['task_labels'])

# for tt in range(10):
#     print(tst[tt]['task_labels'][0])

# shape of the 'image' entry of the first training item
first_trn_item = trn[0]
print(np.shape(first_trn_item['image']))

[2107, 234, 260]
[1 1 0]
[0 0 0]
[0 0 0]
torch.Size([3, 224, 224])


In [5]:
# Try out the fine-tuning routine: one run per task index (0, 1 and 2).
# save_model_dir is passed along as the model folder (presumably where the
# fine-tuned weights get written -- confirm in shapedimutils).
task_inds = range(3)
for task in task_inds:
    s.fine_tune_alexnet_binary(
        image_dir,
        save_model_dir,
        task=task,
        maxsteps=201,
    )

step 0, trn loss = 1.09290, trn accuracy = 0.41000
step 1, trn loss = 1.00200, trn accuracy = 0.36000
step 2, trn loss = 0.97736, trn accuracy = 0.40000
step 3, trn loss = 0.80643, trn accuracy = 0.52000
step 4, trn loss = 0.87658, trn accuracy = 0.41000
step 5, trn loss = 0.88047, trn accuracy = 0.42000
step 6, trn loss = 0.92782, trn accuracy = 0.41000
step 7, trn loss = 0.91079, trn accuracy = 0.41000
step 8, trn loss = 0.87093, trn accuracy = 0.38000
step 9, trn loss = 0.83511, trn accuracy = 0.44000
step 10, trn loss = 0.87489, trn accuracy = 0.39000
step 10, val loss = 0.81901, val accuracy = 0.45
step 11, trn loss = 0.89669, trn accuracy = 0.38000
step 12, trn loss = 0.80931, trn accuracy = 0.47000
step 13, trn loss = 0.94467, trn accuracy = 0.37000
step 14, trn loss = 0.86783, trn accuracy = 0.46000
step 15, trn loss = 0.81629, trn accuracy = 0.50000
step 16, trn loss = 0.85814, trn accuracy = 0.46000
step 17, trn loss = 0.94674, trn accuracy = 0.35000
step 18, trn loss = 0.890

step 142, trn loss = 0.63196, trn accuracy = 0.64000
step 143, trn loss = 0.53996, trn accuracy = 0.76000
step 144, trn loss = 0.64984, trn accuracy = 0.66000
step 145, trn loss = 0.50773, trn accuracy = 0.76000
step 146, trn loss = 0.55580, trn accuracy = 0.73000
step 147, trn loss = 0.51891, trn accuracy = 0.74000
step 148, trn loss = 0.58660, trn accuracy = 0.71000
step 149, trn loss = 0.45970, trn accuracy = 0.81000
step 150, trn loss = 0.47653, trn accuracy = 0.80000
step 150, val loss = 0.49209, val accuracy = 0.74
step 151, trn loss = 0.50496, trn accuracy = 0.76000
step 152, trn loss = 0.50446, trn accuracy = 0.79000
step 153, trn loss = 0.53128, trn accuracy = 0.71429
starting epoch 7
step 154, trn loss = 0.54796, trn accuracy = 0.76000
step 155, trn loss = 0.57695, trn accuracy = 0.69000
step 156, trn loss = 0.53662, trn accuracy = 0.76000
step 157, trn loss = 0.51450, trn accuracy = 0.75000
step 158, trn loss = 0.56799, trn accuracy = 0.69000
step 159, trn loss = 0.58415, tr

step 81, trn loss = 0.54611, trn accuracy = 0.70000
step 82, trn loss = 0.47882, trn accuracy = 0.78000
step 83, trn loss = 0.47029, trn accuracy = 0.80000
step 84, trn loss = 0.60562, trn accuracy = 0.69000
step 85, trn loss = 0.55649, trn accuracy = 0.68000
step 86, trn loss = 0.49626, trn accuracy = 0.79000
step 87, trn loss = 0.49614, trn accuracy = 0.71429
starting epoch 4
step 88, trn loss = 0.54144, trn accuracy = 0.73000
step 89, trn loss = 0.52188, trn accuracy = 0.76000
step 90, trn loss = 0.48278, trn accuracy = 0.79000
step 90, val loss = 0.40273, val accuracy = 0.91
step 91, trn loss = 0.50099, trn accuracy = 0.78000
step 92, trn loss = 0.48924, trn accuracy = 0.79000
step 93, trn loss = 0.49511, trn accuracy = 0.74000
step 94, trn loss = 0.48647, trn accuracy = 0.77000
step 95, trn loss = 0.49890, trn accuracy = 0.81000
step 96, trn loss = 0.50367, trn accuracy = 0.72000
step 97, trn loss = 0.44430, trn accuracy = 0.86000
step 98, trn loss = 0.46484, trn accuracy = 0.7900

step 20, trn loss = 0.71404, trn accuracy = 0.59000
step 20, val loss = 0.67357, val accuracy = 0.53
step 21, trn loss = 1.33074, trn accuracy = 0.28571
starting epoch 1
step 22, trn loss = 0.67586, trn accuracy = 0.66000
step 23, trn loss = 0.68998, trn accuracy = 0.59000
step 24, trn loss = 0.80468, trn accuracy = 0.51000
step 25, trn loss = 0.66567, trn accuracy = 0.65000
step 26, trn loss = 0.66188, trn accuracy = 0.62000
step 27, trn loss = 0.76009, trn accuracy = 0.55000
step 28, trn loss = 0.66194, trn accuracy = 0.67000
step 29, trn loss = 0.66310, trn accuracy = 0.66000
step 30, trn loss = 0.74863, trn accuracy = 0.54000
step 30, val loss = 0.67394, val accuracy = 0.54
step 31, trn loss = 0.79289, trn accuracy = 0.51000
step 32, trn loss = 0.70978, trn accuracy = 0.55000
step 33, trn loss = 0.66378, trn accuracy = 0.62000
step 34, trn loss = 0.68881, trn accuracy = 0.65000
step 35, trn loss = 0.66439, trn accuracy = 0.61000
step 36, trn loss = 0.67026, trn accuracy = 0.63000
s

step 160, val loss = 0.60390, val accuracy = 0.65
step 161, trn loss = 0.67276, trn accuracy = 0.68000
step 162, trn loss = 0.57687, trn accuracy = 0.72000
step 163, trn loss = 0.69483, trn accuracy = 0.57000
step 164, trn loss = 0.78025, trn accuracy = 0.49000
step 165, trn loss = 0.62934, trn accuracy = 0.64000
step 166, trn loss = 0.69059, trn accuracy = 0.57000
step 167, trn loss = 0.66663, trn accuracy = 0.61000
step 168, trn loss = 0.61503, trn accuracy = 0.64000
step 169, trn loss = 0.67685, trn accuracy = 0.58000
step 170, trn loss = 0.68375, trn accuracy = 0.60000
step 170, val loss = 0.58137, val accuracy = 0.70
step 171, trn loss = 0.69422, trn accuracy = 0.64000
step 172, trn loss = 0.69393, trn accuracy = 0.58000
step 173, trn loss = 0.56339, trn accuracy = 0.68000
step 174, trn loss = 0.64176, trn accuracy = 0.63000
step 175, trn loss = 0.59903, trn accuracy = 0.71429
starting epoch 8
step 176, trn loss = 0.66156, trn accuracy = 0.60000
step 177, trn loss = 0.63979, trn a

In [3]:
# Try out the evaluation routine on the fine-tuned model for a single task.
# NOTE(review): this cell's execution count is lower than the fine-tuning cell's,
# so any stored output may predate fine-tuning -- re-run top to bottom to confirm.
task_to_eval = 0
s.eval_alexnet_fine_tune(image_dir, save_model_dir, task=task_to_eval)

loading from /mnt/neurosphere/serenceslab2/maggie/shapeDim/Modeling/saved_models/AlexNet_finetune_task0.pt

step 4, val loss = 0.84007, val accuracy = 0.35
step 4, test loss = 0.88342, test accuracy = 0.33


In [None]:
## OLD
# what object category does it think the shapes are?
# NOTE(review): `out` is not defined in any visible cell of this notebook, so this
# cell cannot run on a fresh kernel -- confirm where it came from before reviving it.
# .cpu() is a no-op for a CPU tensor and lets the conversion work when `out` is on a GPU.
final_activ = out.detach().cpu().numpy()
# index of the largest activation along axis 1, one value per row
# (presumably rows are images and columns are output classes -- TODO confirm)
np.argmax(final_activ,1)

array([530, 530, 530, 530, 530, 530, 530, 530, 530, 530, 530, 530, 818,
       530, 530, 530])

In [None]:
## OLD
# try to classify category based on activs at different layers
# NOTE(review): relies on `coords2load`, `nIms`, `activ` and `layer_names`, none of
# which are defined in the visible cells -- this cell cannot run on a fresh kernel.

layers2do = np.arange(0,13,1)
axis2discrim = 0  # 0 or 1
center = 2.5  # center of shape space is the "boundary"
image_labels = np.int64(coords2load[:,axis2discrim]>center) # create binary labels
image_inds = np.arange(0,nIms)

for ll in layers2do:
    # first, reshape to [nIms x nUnits]
    # disregarding difference between channels/spatial dims for now
    nUnitsTotal = np.prod(np.shape(activ[ll])[1:])
    activ_full = np.reshape(activ[ll].numpy(),[nIms, nUnitsTotal])

    # cross validated decoding leaving one im out at a time
    predlabs = np.zeros(np.shape(image_labels))
    for ii in range(nIms):

        # build the held-out mask once and reuse it for both labels and data
        is_tst = image_inds==ii
        trnlabs = image_labels[~is_tst]
        tstlabs = image_labels[is_tst]
        trndat = activ_full[~is_tst,:]
        tstdat = activ_full[is_tst,:]

        # train/test SVM classifier
        classifier = svm.SVC()
        classifier.fit(trndat, trnlabs)
        pred = classifier.predict(tstdat)
        # tstdat holds exactly one row, so pred is a length-1 array; store its single
        # element rather than relying on implicit array-to-scalar conversion
        # (deprecated in recent NumPy releases)
        predlabs[ii] = pred[0]

    # fraction of held-out images whose predicted label matches the true label
    acc = np.mean(predlabs==image_labels)
    print('Layer %s: acc = %.2f'%(layer_names[ll],acc))
  

Layer 1_Conv2d: acc = 0.56
Layer 1_ReLU: acc = 0.56
Layer 1_MaxPool2d: acc = 0.94
Layer 2_Conv2d: acc = 0.94
Layer 2_ReLU: acc = 0.94
Layer 2_MaxPool2d: acc = 1.00
Layer 3_Conv2d: acc = 1.00
Layer 3_ReLU: acc = 1.00
Layer 4_Conv2d: acc = 1.00
Layer 4_ReLU: acc = 1.00
Layer 5_Conv2d: acc = 0.94
Layer 5_ReLU: acc = 0.94
Layer 5_MaxPool2d: acc = 0.88
