In [None]:
# import numpy as np
# a = np.load('./predictedvgg/subj001/LOC_conv5/aardvark/aardvark.npy')

In [2]:
import os
import torch
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
from torch.autograd import Variable
import torch.nn.functional as F
from utils import listdir

In [3]:
class PredictedVoxelDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.list_category = sorted(os.listdir('predicted_EVC_conv5'))[0:20]
        voxel_path = []
        categories = []
        for i, c in enumerate(self.list_category):
            voxel_path.extend([os.path.join(self.root_dir, c, s) for s in os.listdir(os.path.join(self.root_dir, c))])
            nSample = len([s for s in os.listdir(os.path.join(self.root_dir, c))])
            for n in range(nSample):
                categories.append(i)
        self.voxel_path = voxel_path
        self.categories = categories
    def __len__(self):
        return len(self.voxel_path)
    def __getitem__(self, idx):
        #category = self.list_category[idx]
        voxels = np.load(self.voxel_path[idx])[0]
        sample = {'voxel': voxels, 'category': self.categories[idx]}
        if self.transform:
            sample = self.transform(sample)
        return sample


In [4]:
voxel_dataset = PredictedVoxelDataset(root_dir='./predictedvggall_random/subj001/LOC_conv5')
print(len(voxel_dataset))

278


In [5]:
print(voxel_dataset[0]['voxel'].shape)

(200,)


In [6]:
# for i_batch, sample_batched in enumerate(dataloader):
#     print(i_batch, sample_batched['voxel'].shape, sample_batched['category'])

batch_size = 10
n_iters = 10000
epochs = n_iters / (len(voxel_dataset)/batch_size)
nVox = voxel_dataset[0]['voxel'].shape[0]
nClass = len(set(voxel_dataset.categories))
lr_rate = 0.1

In [7]:
#train_set, val_set = torch.utils.data.random_split(voxel_dataset, [20092, 600])
train_set, val_set = torch.utils.data.random_split(voxel_dataset, [len(voxel_dataset)-20, 20])

train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=0)
test_loader = DataLoader(val_set, batch_size=batch_size, shuffle=True, num_workers=0)

In [8]:
class LinearClassifier(torch.nn.Module):
    def __init__(self, nVox, nClass):
        super(LinearClassifier, self).__init__()
        self.linear = torch.nn.Linear(nVox, nClass)
        
    def forward(self, x):
        x = self.linear(x)
        return x
    
class NonLinearClassifier(torch.nn.Module):
    def __init__(self, nVox, nClass):
        super(NonLinearClassifier, self).__init__()
        self.linear = torch.nn.Linear(nVox, nVox)
        self.linear2 = torch.nn.Linear(nVox, nClass)
        
    def forward(self, x):
        x = self.linear(x)
        x = F.relu(x)
        x = self.linear2(x)
        x = F.relu(x)
        return x    

In [9]:
#linear_model = LinearClassifier(nVox, nClass)
linear_model = NonLinearClassifier(nVox, nClass)
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(linear_model.parameters(), lr=lr_rate)
#optimizer = torch.optim.Adam(linear_model.parameters(), lr=lr_rate)

In [10]:
# # for i, sample_batched in enumerate(train_loader):
# #     print(i, sample_batched['voxel'].size(), sample_batched['category'])
# print(len(test_loader))
# for sample in test_loader:
#     print(sample)

In [11]:
# train the linear model

iteration = 0
for epoch in range(int(epochs)):
    for i, sample_batched in enumerate(train_loader):
        voxels = Variable(torch.Tensor(sample_batched['voxel'].float()))
        categories = Variable(sample_batched['category'])
        optimizer.zero_grad()
        outputs = linear_model(voxels)

        loss = criterion(outputs, categories)
        loss.backward()
        optimizer.step()
        
        iteration+=1
        if iteration%100==0:
            correct = 0
            total = 0
            for samples in test_loader:
                voxels = Variable(torch.Tensor(samples['voxel'].float()))
                outputs = linear_model(voxels)
                _, predicted = torch.max(outputs.data, 1)
                total += outputs.size(0)
                if len(predicted) == len(categories):
                    correct += (predicted == categories).sum()
                    accuracy = 100 * correct.item() / total
                    print("Iteration: {}. Loss: {}. Accuracy: {}%".format(iteration, loss.item(), accuracy))

Iteration: 100. Loss: 3.013476848602295. Accuracy: 0.0%
Iteration: 100. Loss: 3.013476848602295. Accuracy: 5.0%
Iteration: 200. Loss: 2.747680187225342. Accuracy: 0.0%
Iteration: 200. Loss: 2.747680187225342. Accuracy: 5.0%
Iteration: 300. Loss: 2.885913848876953. Accuracy: 10.0%
Iteration: 300. Loss: 2.885913848876953. Accuracy: 15.0%
Iteration: 400. Loss: 2.5339102745056152. Accuracy: 0.0%
Iteration: 400. Loss: 2.5339102745056152. Accuracy: 10.0%
Iteration: 500. Loss: 2.603872299194336. Accuracy: 10.0%
Iteration: 500. Loss: 2.603872299194336. Accuracy: 5.0%
Iteration: 600. Loss: 2.4858298301696777. Accuracy: 0.0%
Iteration: 600. Loss: 2.4858298301696777. Accuracy: 5.0%
Iteration: 700. Loss: 2.4167847633361816. Accuracy: 10.0%
Iteration: 700. Loss: 2.4167847633361816. Accuracy: 15.0%
Iteration: 800. Loss: 2.82761812210083. Accuracy: 0.0%
Iteration: 800. Loss: 2.82761812210083. Accuracy: 0.0%
Iteration: 900. Loss: 3.0304291248321533. Accuracy: 0.0%
Iteration: 900. Loss: 3.0304291248321