In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
import glob
import os.path as osp
import numpy as np
from PIL import Image

In [2]:
# Use GPU if available, otherwise stick with cpu
use_cuda = torch.cuda.is_available()
torch.manual_seed(123)
device = torch.device(cuda if use_cuda else "cpu")
print(device)

cpu


## 1. Custom Dataset
PyTorch has many built-in datasets such as MNIST and CIFAR. In this tutorial, we demonstrate how to write your own dataset by implementing a custom MNIST dataset class. Use [this link](https://github.com/myleott/mnist_png/blob/master/mnist_png.tar.gz?raw=true) to download the mnist png dataset.

In [3]:
class PDBSS_ATOMS(Dataset):
    """
    A customized data loader for PDBSS (PDB Snapshots Database).
    """
    def __init__(self,
                 root,
                 transform=None,
                 preload=False):
        """ Intialize the PDBSS dataset
        
        Args:
            - root: root directory of the dataset
            - tranform: a custom tranform function
            - preload: if preload the dataset into memory
        """
        self.images = None
        self.labels = None
        self.filenames = []
        self.root = root
        self.transform = transform

        # read filenames
        for i in range(740): #740 categories
            filenames = glob.glob(osp.join(root, str(i), '*.png'))
            for fn in filenames:
                self.filenames.append((fn, i)) # (filename, label) pair
                
        # if preload dataset into memory
        if preload:
            self._preload()
            
        self.len = len(self.filenames)
                              
    def _preload(self):
        """
        Preload dataset to memory
        """
        self.labels = []
        self.images = []
        for image_fn, label in self.filenames:            
            # load images
            image = Image.open(image_fn)
            # avoid too many opened files bug
            self.images.append(image.copy())
            image.close()
            self.labels.append(label)

    def __getitem__(self, index):
        """ Get a sample from the dataset
        """
        if self.images is not None:
            # If dataset is preloaded
            image = self.images[index]
            label = self.labels[index]
        else:
            # If on-demand data loading
            image_fn, label = self.filenames[index]
            image = Image.open(image_fn)
            
        # May use transform function to transform samples
        # e.g., random crop, whitening
        if self.transform is not None:
            image = self.transform(image)
            #image = transform.resize(image, (300,300))
            #image = self.transform(image)
        # return image and label
        return image, label

    def __len__(self):
        """
        Total number of samples in the dataset
        """
        return self.len
    

In [5]:
# Create the PDBSS dataset. 
# transforms.ToTensor() automatically converts PIL images to
# torch tensors with range [0, 1]

data_transform = transforms.Compose([
        transforms.Resize(size=(300,300)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.0, 0.0, 0.0],
                             std=[0.25, 0.25, 0.25])
    ])

trainset = PDBSS_ATOMS(
    root='/Users/user/cs231n_Database_dir/CS231N_PDB_SNAPSHOTS_DATABASE_5000_ATOMS/train',
    preload=False, transform=data_transform,
)
# Use the torch dataloader to iterate through the dataset
trainset_loader = DataLoader(trainset, batch_size=64*4, shuffle=True, num_workers=2)  #batch_size = original 64. ideally > 740.

# load the testset
testset = PDBSS_ATOMS(
    root='/Users/user/cs231n_Database_dir/CS231N_PDB_SNAPSHOTS_DATABASE_5000_ATOMS/test',
    preload=False, transform=data_transform,
)
# Use the torch dataloader to iterate through the dataset
testset_loader = DataLoader(testset, batch_size=100, shuffle=False, num_workers=2) #batch_size = original 100

In [6]:
print(len(trainset))
print(len(testset))
testset[0][0].size()

305600
20736


torch.Size([3, 300, 300])

### Define a Conv Net


In [7]:
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 10, kernel_size=3, stride=1,padding=1) #(300+1x2-3)/1+1=300  300x300x10 pool 150x150x10
        self.conv2 = nn.Conv2d(10, 20, kernel_size=3, stride=1,padding=1) #(150+1x2-3)/1+1=150  150x150x10 pool 75x75x20
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(75*75*20, 740*3)  #75x75x20 
        self.fc2 = nn.Linear(740*3, 740) #740 total categories

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        x = x.view(-1, 75*75*20) # reshape 75x75x20
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)

model = Net().to(device)
#optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
optimizer = optim.Adam(model.parameters(), lr=0.001)

In [8]:
def test():
    model.eval()  # set evaluation mode
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in testset_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.nll_loss(output, target, size_average=False).item() # sum up batch loss
            pred = output.max(1, keepdim=True)[1] # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()
            
            print('{} {}'.format(pred, target.view_as(pred)))
            print('{}'.format(pred.eq(target.view_as(pred)).sum().item()))
 
            #if(pred.eq(target.view_as(pred)).sum().item() == 1):
            #    print('{} correct'.format(target))
            #else:
            #    print('{} fail'.format(target))

    test_loss /= len(testset_loader.dataset)
    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(testset_loader.dataset),
        100. * correct / len(testset_loader.dataset)))
    
    

In [9]:
def see_train_set():
    model.eval()  # set evaluation mode
    trainset_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in trainset_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            trainset_loss += F.nll_loss(output, target, size_average=False).item() # sum up batch loss
            pred = output.max(1, keepdim=True)[1] # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()
            
            print('{} {}'.format(pred, target.view_as(pred)))
            print('{}'.format(pred.eq(target.view_as(pred)).sum().item()))
 
            #if(pred.eq(target.view_as(pred)).sum().item() == 1):
            #    print('{} correct'.format(target))
            #else:
            #    print('{} fail'.format(target))

    trainset_loss /= len(trainset_loader.dataset)
    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        trainset_loss, correct, len(trainset_loader.dataset),
        100. * correct / len(trainset_loader.dataset)))

## 2. Save the model (model checkpointing)

Now we have a model! Obviously we do not want to retrain the model everytime we want to use it. Plus if you are training a super big model, you probably want to save checkpoint periodically so that you can always fall back to the last checkpoint in case something bad happened or you simply want to test models at different training iterations.

Model checkpointing is fairly simple in PyTorch. First, we define a helper function that can save a model to the disk

In [10]:
def save_checkpoint(checkpoint_path, model, optimizer):
    state = {'state_dict': model.state_dict(),
             'optimizer' : optimizer.state_dict()}
    torch.save(state,checkpoint_path)
    print('model saved to %s' % checkpoint_path)
    
def load_checkpoint(checkpoint_path, model, optimizer):
    state = torch.load(checkpoint_path)
    model.load_state_dict(state['state_dict'])
    optimizer.load_state_dict(state['optimizer'])
    print('model loaded from %s' % checkpoint_path)
    

### Define a training loop with model checkpointing

In [11]:
# create a new model
model = Net().to(device)
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
# load from the final checkpoint
load_checkpoint('model_training99p_test14p/pdbss-ATOMS-6000.pth', model, optimizer)
# should give you the final model accuracy
see_train_set()

model loaded from model_training99p_test14p/pdbss-ATOMS-6000.pth
tensor([[ 517],
        [ 504],
        [ 498],
        [ 446],
        [ 611],
        [ 359],
        [ 678],
        [ 139],
        [ 359],
        [ 359],
        [ 586],
        [ 299],
        [ 635],
        [ 359],
        [ 678],
        [ 678],
        [ 661],
        [ 286],
        [ 504],
        [  96],
        [ 127],
        [ 516],
        [ 678],
        [ 345],
        [ 504],
        [ 359],
        [ 178],
        [ 714],
        [ 504],
        [ 359],
        [ 293],
        [ 678],
        [ 635],
        [ 395],
        [ 133],
        [ 734],
        [ 504],
        [ 526],
        [ 586],
        [ 359],
        [ 449],
        [  77],
        [ 611],
        [ 324],
        [ 175],
        [ 330],
        [ 373],
        [ 656],
        [ 433],
        [ 293],
        [ 678],
        [ 624],
        [ 549],
        [ 373],
        [ 504],
        [ 391],
        [ 313],
        [ 558],
       

tensor([[ 373],
        [ 661],
        [ 153],
        [ 678],
        [ 678],
        [ 611],
        [ 586],
        [ 558],
        [ 359],
        [ 611],
        [ 634],
        [ 536],
        [ 359],
        [ 359],
        [ 611],
        [ 359],
        [ 611],
        [ 359],
        [ 373],
        [  13],
        [ 151],
        [ 504],
        [ 480],
        [ 683],
        [ 373],
        [ 373],
        [ 345],
        [ 678],
        [ 661],
        [ 449],
        [ 726],
        [ 373],
        [ 516],
        [ 449],
        [ 405],
        [ 450],
        [ 173],
        [ 373],
        [  51],
        [ 119],
        [ 395],
        [ 127],
        [ 466],
        [ 151],
        [ 714],
        [ 689],
        [  96],
        [ 373],
        [ 713],
        [ 295],
        [ 611],
        [ 154],
        [ 362],
        [ 359],
        [ 268],
        [ 427],
        [ 359],
        [ 386],
        [ 359],
        [ 117],
        [ 529],
        [ 148],
        

tensor([[ 153],
        [ 345],
        [ 635],
        [ 549],
        [ 313],
        [ 273],
        [ 359],
        [ 359],
        [ 373],
        [ 663],
        [ 521],
        [ 395],
        [ 362],
        [ 516],
        [ 409],
        [ 678],
        [ 452],
        [ 313],
        [ 373],
        [ 504],
        [ 501],
        [ 327],
        [  16],
        [ 491],
        [ 327],
        [ 299],
        [ 569],
        [ 430],
        [ 413],
        [ 453],
        [ 359],
        [ 678],
        [ 373],
        [ 119],
        [ 678],
        [ 678],
        [ 685],
        [ 635],
        [  51],
        [ 373],
        [ 268],
        [ 359],
        [ 678],
        [ 331],
        [ 611],
        [ 466],
        [ 373],
        [ 702],
        [ 586],
        [ 373],
        [ 359],
        [ 395],
        [ 345],
        [ 204],
        [ 342],
        [ 678],
        [ 539],
        [ 452],
        [ 678],
        [ 569],
        [ 685],
        [ 130],
        

Process Process-2:
Process Process-1:
Traceback (most recent call last):
Traceback (most recent call last):
  File "/Users/user/anaconda3/envs/cs231n/lib/python3.6/multiprocessing/process.py", line 258, in _bootstrap
    self.run()
  File "/Users/user/anaconda3/envs/cs231n/lib/python3.6/multiprocessing/process.py", line 258, in _bootstrap
    self.run()
  File "/Users/user/anaconda3/envs/cs231n/lib/python3.6/multiprocessing/process.py", line 93, in run
    self._target(*self._args, **self._kwargs)
  File "/Users/user/anaconda3/envs/cs231n/lib/python3.6/multiprocessing/process.py", line 93, in run
    self._target(*self._args, **self._kwargs)
  File "/Users/user/anaconda3/envs/cs231n/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 52, in _worker_loop
    r = index_queue.get()
  File "/Users/user/anaconda3/envs/cs231n/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 52, in _worker_loop
    r = index_queue.get()
  File "/Users/user/anaconda3/envs/cs231n/

KeyboardInterrupt: 

In [12]:
test()

tensor([[ 444],
        [ 480],
        [ 504],
        [ 359],
        [ 226],
        [ 373],
        [ 504],
        [ 714],
        [ 224],
        [ 275],
        [ 477],
        [ 359],
        [ 359],
        [ 425],
        [ 678],
        [ 359],
        [ 477],
        [ 432],
        [ 226],
        [ 735],
        [ 226],
        [ 226],
        [ 480],
        [ 678],
        [ 226],
        [ 226],
        [ 678],
        [ 359],
        [ 359],
        [ 359],
        [  51],
        [ 226],
        [ 388],
        [ 226],
        [ 678],
        [ 452],
        [ 273],
        [ 359],
        [ 345],
        [ 359],
        [ 226],
        [ 359],
        [ 226],
        [  51],
        [ 359],
        [ 226],
        [  51],
        [ 558],
        [ 359],
        [ 359],
        [ 178],
        [ 475],
        [ 226],
        [ 480],
        [ 388],
        [ 714],
        [ 714],
        [ 359],
        [ 345],
        [ 714],
        [ 611],
        [ 714],
        

tensor([[ 611],
        [ 529],
        [ 273],
        [ 549],
        [  16],
        [  16],
        [ 273],
        [ 273],
        [ 395],
        [ 273],
        [ 273],
        [  16],
        [ 193],
        [ 324],
        [ 273],
        [ 273],
        [ 395],
        [ 395],
        [  16],
        [ 549],
        [ 395],
        [ 395],
        [ 359],
        [ 373],
        [ 395],
        [ 395],
        [ 395],
        [ 359],
        [ 332],
        [ 678],
        [ 678],
        [ 395],
        [ 395],
        [ 395],
        [ 395],
        [ 395],
        [ 678],
        [ 504],
        [ 395],
        [ 678],
        [ 678],
        [ 661],
        [ 678],
        [ 504],
        [ 678],
        [ 714],
        [ 395],
        [ 678],
        [ 678],
        [ 678],
        [ 714],
        [ 395],
        [ 359],
        [ 504],
        [ 395],
        [ 359],
        [ 678],
        [ 678],
        [ 439],
        [ 714],
        [ 558],
        [ 395],
        

tensor([[  42],
        [  42],
        [  42],
        [  42],
        [ 373],
        [ 678],
        [ 504],
        [ 395],
        [ 359],
        [  42],
        [  42],
        [  42],
        [  42],
        [ 678],
        [ 714],
        [  42],
        [  42],
        [  42],
        [  42],
        [  42],
        [  42],
        [  42],
        [ 678],
        [ 532],
        [  42],
        [  42],
        [ 678],
        [ 678],
        [ 635],
        [ 341],
        [ 678],
        [ 310],
        [ 661],
        [  42],
        [  42],
        [ 359],
        [  42],
        [  42],
        [  42],
        [  42],
        [ 487],
        [ 425],
        [ 425],
        [ 359],
        [ 425],
        [ 424],
        [ 425],
        [ 425],
        [ 395],
        [ 425],
        [ 425],
        [ 678],
        [ 359],
        [ 425],
        [ 425],
        [ 425],
        [ 126],
        [ 395],
        [ 395],
        [ 487],
        [ 126],
        [ 635],
        

tensor([[  83],
        [ 359],
        [  83],
        [  83],
        [  83],
        [  83],
        [  83],
        [  83],
        [  83],
        [  83],
        [  83],
        [  83],
        [  83],
        [ 359],
        [ 373],
        [  83],
        [  83],
        [ 678],
        [  83],
        [ 726],
        [ 273],
        [  83],
        [ 691],
        [  83],
        [ 395],
        [  83],
        [  83],
        [  83],
        [ 359],
        [  83],
        [ 691],
        [  83],
        [  83],
        [ 359],
        [  83],
        [ 359],
        [ 359],
        [  83],
        [ 678],
        [ 359],
        [ 504],
        [ 359],
        [  83],
        [  83],
        [  83],
        [ 359],
        [  83],
        [ 359],
        [  83],
        [  83],
        [  83],
        [  83],
        [  83],
        [ 678],
        [  83],
        [  83],
        [  83],
        [ 359],
        [  83],
        [  83],
        [ 359],
        [ 359],
        

tensor([[ 549],
        [ 373],
        [ 178],
        [ 178],
        [ 178],
        [ 359],
        [ 586],
        [ 373],
        [ 549],
        [ 178],
        [ 178],
        [ 313],
        [ 635],
        [ 178],
        [ 178],
        [ 635],
        [  79],
        [  79],
        [  79],
        [  79],
        [  79],
        [  79],
        [  79],
        [  79],
        [  79],
        [  79],
        [  79],
        [  79],
        [  79],
        [ 504],
        [ 504],
        [  79],
        [  79],
        [  79],
        [ 678],
        [ 504],
        [ 504],
        [ 504],
        [  79],
        [  79],
        [ 504],
        [  79],
        [  79],
        [  79],
        [  79],
        [  79],
        [  79],
        [  79],
        [  79],
        [  79],
        [  79],
        [  79],
        [  79],
        [  79],
        [  79],
        [  79],
        [  79],
        [  79],
        [  79],
        [  79],
        [  79],
        [  79],
        

tensor([[ 683],
        [ 151],
        [ 151],
        [ 151],
        [ 324],
        [ 151],
        [ 151],
        [ 151],
        [ 683],
        [ 683],
        [ 151],
        [ 151],
        [ 446],
        [ 683],
        [ 151],
        [ 151],
        [ 324],
        [ 683],
        [ 683],
        [ 151],
        [ 341],
        [ 625],
        [ 151],
        [ 504],
        [ 327],
        [ 324],
        [ 611],
        [ 151],
        [ 151],
        [ 151],
        [ 683],
        [ 683],
        [ 151],
        [ 359],
        [ 151],
        [ 151],
        [ 178],
        [ 678],
        [ 359],
        [  51],
        [ 545],
        [ 573],
        [ 359],
        [ 586],
        [ 427],
        [ 130],
        [ 446],
        [ 178],
        [ 689],
        [ 273],
        [ 395],
        [ 395],
        [ 446],
        [ 332],
        [ 395],
        [ 549],
        [ 678],
        [ 678],
        [ 428],
        [ 395],
        [ 678],
        [ 678],
        

tensor([[ 395],
        [ 404],
        [ 678],
        [ 359],
        [ 714],
        [ 525],
        [ 359],
        [ 359],
        [ 359],
        [ 678],
        [ 678],
        [ 359],
        [ 345],
        [ 456],
        [ 480],
        [ 359],
        [ 395],
        [ 178],
        [ 359],
        [ 359],
        [ 373],
        [ 359],
        [ 359],
        [ 345],
        [ 359],
        [ 678],
        [ 359],
        [ 359],
        [ 345],
        [ 359],
        [ 359],
        [ 678],
        [ 678],
        [  51],
        [ 678],
        [ 504],
        [ 359],
        [ 678],
        [ 373],
        [ 678],
        [ 359],
        [ 173],
        [ 661],
        [ 678],
        [  51],
        [ 359],
        [ 359],
        [ 678],
        [ 661],
        [ 359],
        [ 373],
        [ 345],
        [ 359],
        [  51],
        [ 714],
        [ 600],
        [ 359],
        [ 359],
        [ 661],
        [ 370],
        [ 359],
        [ 504],
        

tensor([[ 273],
        [ 193],
        [ 359],
        [ 273],
        [ 675],
        [ 675],
        [ 327],
        [ 324],
        [ 359],
        [ 586],
        [ 344],
        [ 678],
        [ 395],
        [ 359],
        [ 359],
        [ 345],
        [ 504],
        [ 328],
        [ 395],
        [ 395],
        [ 173],
        [ 395],
        [ 395],
        [ 498],
        [ 395],
        [ 395],
        [ 678],
        [ 504],
        [ 395],
        [ 678],
        [ 437],
        [ 678],
        [ 395],
        [ 395],
        [ 395],
        [ 395],
        [ 151],
        [ 359],
        [ 359],
        [ 359],
        [ 395],
        [ 678],
        [ 359],
        [ 395],
        [ 678],
        [ 635],
        [ 395],
        [ 678],
        [ 395],
        [ 678],
        [ 395],
        [ 446],
        [ 611],
        [ 678],
        [ 395],
        [ 395],
        [ 395],
        [ 359],
        [  27],
        [ 452],
        [ 395],
        [  27],
        

tensor([[ 286],
        [ 611],
        [ 359],
        [ 344],
        [ 334],
        [ 356],
        [ 356],
        [ 286],
        [ 133],
        [ 133],
        [ 293],
        [ 611],
        [ 611],
        [ 293],
        [ 286],
        [ 611],
        [ 645],
        [ 286],
        [ 645],
        [ 611],
        [ 286],
        [ 645],
        [ 293],
        [ 286],
        [ 466],
        [ 334],
        [ 286],
        [ 293],
        [ 611],
        [ 645],
        [ 645],
        [ 133],
        [ 370],
        [ 678],
        [ 395],
        [ 717],
        [ 273],
        [ 353],
        [ 273],
        [  51],
        [ 678],
        [ 273],
        [ 611],
        [ 327],
        [ 359],
        [ 273],
        [ 678],
        [ 273],
        [ 678],
        [ 678],
        [ 193],
        [ 598],
        [ 466],
        [ 466],
        [ 273],
        [ 273],
        [ 678],
        [ 678],
        [ 273],
        [ 177],
        [ 477],
        [ 466],
        

tensor([[ 345],
        [ 273],
        [ 466],
        [ 193],
        [ 678],
        [ 193],
        [ 586],
        [ 344],
        [ 678],
        [ 273],
        [ 134],
        [ 477],
        [ 586],
        [ 345],
        [ 539],
        [ 345],
        [ 134],
        [ 414],
        [ 414],
        [ 134],
        [ 345],
        [ 273],
        [ 134],
        [  33],
        [ 273],
        [ 133],
        [ 273],
        [ 273],
        [ 273],
        [ 193],
        [ 253],
        [ 273],
        [ 602],
        [  51],
        [ 678],
        [ 324],
        [ 134],
        [ 539],
        [ 273],
        [  51],
        [ 257],
        [ 345],
        [ 273],
        [ 466],
        [ 134],
        [ 273],
        [  33],
        [ 678],
        [ 446],
        [ 273],
        [ 134],
        [ 345],
        [ 678],
        [ 373],
        [ 395],
        [ 504],
        [ 678],
        [ 230],
        [ 678],
        [ 425],
        [ 395],
        [ 453],
        

tensor([[ 634],
        [ 134],
        [ 345],
        [ 293],
        [ 134],
        [ 714],
        [ 134],
        [ 134],
        [ 735],
        [ 373],
        [ 504],
        [ 661],
        [ 395],
        [ 345],
        [ 678],
        [ 202],
        [ 551],
        [ 504],
        [ 359],
        [ 678],
        [ 359],
        [ 504],
        [ 345],
        [ 395],
        [ 504],
        [ 395],
        [ 273],
        [ 678],
        [ 735],
        [ 563],
        [ 395],
        [ 423],
        [ 423],
        [ 395],
        [ 504],
        [ 678],
        [ 678],
        [ 678],
        [ 504],
        [ 678],
        [ 395],
        [ 504],
        [ 504],
        [ 139],
        [ 661],
        [ 635],
        [  51],
        [ 678],
        [ 678],
        [ 678],
        [ 678],
        [ 504],
        [ 178],
        [ 395],
        [ 635],
        [ 661],
        [ 273],
        [ 678],
        [ 398],
        [ 345],
        [ 678],
        [ 678],
        

tensor([[ 373],
        [ 373],
        [ 153],
        [ 373],
        [ 373],
        [ 373],
        [ 373],
        [ 385],
        [ 373],
        [ 373],
        [ 373],
        [ 373],
        [ 373],
        [ 373],
        [ 373],
        [ 373],
        [ 373],
        [ 373],
        [ 373],
        [ 373],
        [ 373],
        [ 373],
        [ 373],
        [ 359],
        [ 373],
        [ 373],
        [ 359],
        [ 373],
        [ 327],
        [ 200],
        [ 189],
        [ 324],
        [ 324],
        [ 649],
        [ 549],
        [ 134],
        [ 159],
        [ 324],
        [ 661],
        [ 200],
        [ 661],
        [ 661],
        [ 327],
        [ 133],
        [ 133],
        [ 327],
        [ 293],
        [ 293],
        [ 293],
        [ 293],
        [ 446],
        [ 293],
        [ 133],
        [  51],
        [ 293],
        [ 539],
        [ 293],
        [ 324],
        [ 635],
        [ 635],
        [ 133],
        [ 549],
        

Process Process-4:
Process Process-3:
Traceback (most recent call last):
  File "/Users/user/anaconda3/envs/cs231n/lib/python3.6/multiprocessing/process.py", line 258, in _bootstrap
    self.run()
  File "/Users/user/anaconda3/envs/cs231n/lib/python3.6/multiprocessing/process.py", line 93, in run
    self._target(*self._args, **self._kwargs)
  File "/Users/user/anaconda3/envs/cs231n/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 52, in _worker_loop
    r = index_queue.get()
  File "/Users/user/anaconda3/envs/cs231n/lib/python3.6/multiprocessing/queues.py", line 335, in get
    res = self._reader.recv_bytes()
Traceback (most recent call last):
  File "/Users/user/anaconda3/envs/cs231n/lib/python3.6/multiprocessing/connection.py", line 216, in recv_bytes
    buf = self._recv_bytes(maxlength)
  File "/Users/user/anaconda3/envs/cs231n/lib/python3.6/multiprocessing/process.py", line 258, in _bootstrap
    self.run()
  File "/Users/user/anaconda3/envs/cs231n/lib/python3.6/

KeyboardInterrupt: 