In [161]:
from __future__ import print_function, division
import os
import torch
import json
import pandas as pd
from skimage import io, transform
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils

In [162]:
class IMaterialDataset(Dataset):
    def __init__(self, df, transform=None):
        self.df = df
        self.transform = transform
        
    def __len__(self):
        return len(self.df)
    
    def __getitem__(self, idx):
        url = self.df['images'].iloc[idx]
        image = io.imread(url)
        label = self.df['labels'].iloc[idx]
        sample = {'image': image, 'label': label}
        
        if self.transform:
            sample = self.transform(sample)
        
        return sample
        

In [1]:
class Rescale(object):
    """Rescale the image in a sample to a given size.

    Args:
        output_size (tuple or int): Desired output size. If tuple, output is
            matched to output_size. If int, smaller of image edges is matched
            to output_size keeping aspect ratio the same.
    """

    def __init__(self, output_size):
        assert isinstance(output_size, (int, tuple))
        self.output_size = output_size

    def __call__(self, sample):
        image, label = sample['image'], sample['label']

        h, w = image.shape[:2]
        if isinstance(self.output_size, int):
            if h > w:
                new_h, new_w = self.output_size * h / w, self.output_size
            else:
                new_h, new_w = self.output_size, self.output_size * w / h
        else:
            new_h, new_w = self.output_size

        new_h, new_w = int(new_h), int(new_w)

        img = transform.resize(image, (75, 75), mode='constant')

        return {'image': img, 'label': label}
    
class ToTensor(object):
    """Convert ndarrays in sample to Tensors."""

    def __call__(self, sample):
        image, label = sample['image'], sample['label']

        # swap color axis because
        # numpy image: H x W x C
        # torch image: C X H X W
        image = image.transpose((2, 0, 1))
        
        return {'image': torch.from_numpy(image),
                'label': torch.from_numpy(np.array([label]))}
    

In [163]:
path = '/home/edwin/Datasets/competitions/imaterialist-challenge-furniture-2018/'

In [84]:
train_path = f'{path}train.json'

In [None]:
with open(train_path) as d1:
    train_json = json.load(d1)

In [None]:
len(train_json['images'][0:100])

In [None]:
len(train_json['images'])

In [None]:
len(train_json['annotations'])

In [None]:
images_sample = train_json['images'][0:100]
annotations_sample = train_json['annotations'][0:100]

In [None]:
train_json.keys()

In [None]:
data = {'images': [], 'annotations': []}

In [None]:
df = pd.DataFrame(data=data)

In [None]:
df

In [None]:
images_sample[0]

In [None]:
def get_image_id(x): return x['image_id']
def get_url(x): return x['url'][0]
def get_label_id(x): return x['label_id']

In [None]:
get_url(images_sample[0])

In [None]:
get_label_id(annotations_sample[0])

In [None]:
ids = list(map(get_image_id, images_sample))

In [None]:
images = list(map(get_url, images_sample))

In [None]:
labels = list(map(get_label_id, annotations_sample))

In [None]:
data = {'id': ids, 'images': images, 'labels': labels}

In [None]:
df = pd.DataFrame(data=data)

In [None]:
df = df.set_index('id')

In [None]:
df.images.iloc[0]

In [None]:
df.images.iloc[1]

In [None]:
df_sample = df

In [None]:
df_sample.to_pickle('iDataSample.pkl')

In [85]:
df2 = pd.read_pickle('iDataSample.pkl')

In [86]:
imDataset = IMaterialDataset(df2)

In [None]:
df2['images'].iloc[1]

In [None]:
imDataset[1]

In [None]:
imDataset[2]

In [None]:
for i in range(len(imDataset)):
    sample = imDataset[i]
    print(i)

    print(i, sample['image'].shape, sample['label'].shape)
    if i == 3:
        break

In [None]:
sample['image'].shape

In [164]:
class Rescale(object):
    """Rescale the image in a sample to a given size.

    Args:
        output_size (tuple or int): Desired output size. If tuple, output is
            matched to output_size. If int, smaller of image edges is matched
            to output_size keeping aspect ratio the same.
    """

    def __init__(self, output_size):
        assert isinstance(output_size, (int, tuple))
        self.output_size = output_size

    def __call__(self, sample):
        image, label = sample['image'], sample['label']

        h, w = image.shape[:2]
        if isinstance(self.output_size, int):
            if h > w:
                new_h, new_w = self.output_size * h / w, self.output_size
            else:
                new_h, new_w = self.output_size, self.output_size * w / h
        else:
            new_h, new_w = self.output_size

        new_h, new_w = int(new_h), int(new_w)

        img = transform.resize(image, (75, 75), mode='constant')

        return {'image': img, 'label': label}

In [88]:
class ToTensor(object):
    """Convert ndarrays in sample to Tensors."""

    def __call__(self, sample):
        image, label = sample['image'], sample['label']

        # swap color axis because
        # numpy image: H x W x C
        # torch image: C X H X W
        image = image.transpose((2, 0, 1))
        return {'image': torch.from_numpy(image),
                'label': torch.from_numpy(np.array([label]))}

In [165]:
composed = transforms.Compose([Rescale(50), ToTensor()])

In [166]:
transformed_dataset = IMaterialDataset(df=df2, transform=composed)

In [67]:
for i in range(len(transformed_dataset)):
    print(transformed_dataset[i]['image'].shape)

torch.Size([3, 75, 75])
torch.Size([3, 75, 75])
torch.Size([3, 75, 75])
torch.Size([3, 75, 75])
torch.Size([3, 75, 75])
torch.Size([3, 75, 75])
torch.Size([3, 75, 75])
torch.Size([3, 75, 75])
torch.Size([3, 75, 75])
torch.Size([3, 75, 75])
torch.Size([3, 75, 75])
torch.Size([3, 75, 75])
torch.Size([3, 75, 75])
torch.Size([3, 75, 75])
torch.Size([3, 75, 75])
torch.Size([3, 75, 75])
torch.Size([3, 75, 75])
torch.Size([3, 75, 75])
torch.Size([3, 75, 75])
torch.Size([3, 75, 75])
torch.Size([3, 75, 75])
torch.Size([3, 75, 75])
torch.Size([3, 75, 75])
torch.Size([3, 75, 75])
torch.Size([3, 75, 75])
torch.Size([3, 75, 75])
torch.Size([3, 75, 75])
torch.Size([3, 75, 75])
torch.Size([3, 75, 75])
torch.Size([3, 75, 75])
torch.Size([3, 75, 75])
torch.Size([3, 75, 75])
torch.Size([3, 75, 75])
torch.Size([3, 75, 75])
torch.Size([3, 75, 75])
torch.Size([3, 75, 75])
torch.Size([3, 75, 75])
torch.Size([3, 75, 75])
torch.Size([3, 75, 75])
torch.Size([3, 75, 75])
torch.Size([3, 75, 75])
torch.Size([3, 7

In [48]:
a = transformed_dataset[0]

In [167]:
dataloader = DataLoader(transformed_dataset, batch_size=4, shuffle=True, num_workers=1)

In [92]:
for i_batch in enumerate(dataloader):
    print(i_batch)

(0, {'image': 
(0 ,0 ,.,.) = 
  1.0000  1.0000  1.0000  ...   1.0000  0.9961  1.0000
  1.0000  1.0000  1.0000  ...   0.9973  0.9969  1.0000
  1.0000  1.0000  1.0000  ...   0.9961  1.0000  1.0000
           ...             ⋱             ...          
  1.0000  1.0000  1.0000  ...   1.0000  1.0000  1.0000
  1.0000  1.0000  1.0000  ...   1.0000  1.0000  1.0000
  1.0000  1.0000  1.0000  ...   1.0000  1.0000  1.0000

(0 ,1 ,.,.) = 
  0.9968  1.0000  1.0000  ...   0.9961  1.0000  1.0000
  0.9983  1.0000  1.0000  ...   0.9988  1.0000  1.0000
  0.9981  1.0000  1.0000  ...   1.0000  1.0000  1.0000
           ...             ⋱             ...          
  1.0000  1.0000  1.0000  ...   1.0000  1.0000  1.0000
  1.0000  1.0000  1.0000  ...   1.0000  1.0000  1.0000
  1.0000  1.0000  1.0000  ...   1.0000  1.0000  1.0000

(0 ,2 ,.,.) = 
  0.8086  0.8039  0.8039  ...   0.8353  0.8353  0.8471
  0.8056  0.8039  0.7965  ...   0.8362  0.8440  0.8503
  0.8086  0.8039  0.7972  ...   0.8431  0.8505  0.8510
   

(2, {'image': 
(0 ,0 ,.,.) = 
  1.0000  1.0000  1.0000  ...   1.0000  1.0000  1.0000
  1.0000  1.0000  1.0000  ...   1.0000  1.0000  1.0000
  1.0000  1.0000  1.0000  ...   1.0000  1.0000  1.0000
           ...             ⋱             ...          
  1.0000  1.0000  1.0000  ...   1.0000  1.0000  1.0000
  1.0000  1.0000  1.0000  ...   1.0000  1.0000  1.0000
  1.0000  1.0000  1.0000  ...   1.0000  1.0000  1.0000

(0 ,1 ,.,.) = 
  1.0000  1.0000  1.0000  ...   1.0000  1.0000  1.0000
  1.0000  1.0000  1.0000  ...   1.0000  1.0000  1.0000
  1.0000  1.0000  1.0000  ...   1.0000  1.0000  1.0000
           ...             ⋱             ...          
  1.0000  1.0000  1.0000  ...   1.0000  1.0000  1.0000
  1.0000  1.0000  1.0000  ...   1.0000  1.0000  1.0000
  1.0000  1.0000  1.0000  ...   1.0000  1.0000  1.0000

(0 ,2 ,.,.) = 
  1.0000  1.0000  1.0000  ...   1.0000  1.0000  1.0000
  1.0000  1.0000  1.0000  ...   1.0000  1.0000  1.0000
  1.0000  1.0000  1.0000  ...   1.0000  1.0000  1.0000
   

(4, {'image': 
(0 ,0 ,.,.) = 
  1.0000  1.0000  1.0000  ...   1.0000  1.0000  1.0000
  1.0000  1.0000  1.0000  ...   1.0000  1.0000  1.0000
  1.0000  1.0000  1.0000  ...   1.0000  1.0000  1.0000
           ...             ⋱             ...          
  1.0000  1.0000  1.0000  ...   1.0000  1.0000  1.0000
  1.0000  1.0000  1.0000  ...   1.0000  1.0000  1.0000
  1.0000  1.0000  1.0000  ...   1.0000  1.0000  1.0000

(0 ,1 ,.,.) = 
  1.0000  1.0000  1.0000  ...   1.0000  1.0000  1.0000
  1.0000  1.0000  1.0000  ...   1.0000  1.0000  1.0000
  1.0000  1.0000  1.0000  ...   1.0000  1.0000  1.0000
           ...             ⋱             ...          
  1.0000  1.0000  1.0000  ...   1.0000  1.0000  1.0000
  1.0000  1.0000  1.0000  ...   1.0000  1.0000  1.0000
  1.0000  1.0000  1.0000  ...   1.0000  1.0000  1.0000

(0 ,2 ,.,.) = 
  1.0000  1.0000  1.0000  ...   1.0000  1.0000  1.0000
  1.0000  1.0000  1.0000  ...   1.0000  1.0000  1.0000
  1.0000  1.0000  1.0000  ...   1.0000  1.0000  1.0000
   

(6, {'image': 
(0 ,0 ,.,.) = 
  1.0000  1.0000  1.0000  ...   1.0000  1.0000  1.0000
  1.0000  1.0000  1.0000  ...   1.0000  1.0000  1.0000
  1.0000  1.0000  1.0000  ...   1.0000  1.0000  1.0000
           ...             ⋱             ...          
  1.0000  1.0000  1.0000  ...   1.0000  1.0000  1.0000
  1.0000  1.0000  1.0000  ...   1.0000  1.0000  1.0000
  1.0000  1.0000  1.0000  ...   1.0000  1.0000  1.0000

(0 ,1 ,.,.) = 
  1.0000  1.0000  1.0000  ...   1.0000  1.0000  1.0000
  1.0000  1.0000  1.0000  ...   1.0000  1.0000  1.0000
  1.0000  1.0000  1.0000  ...   1.0000  1.0000  1.0000
           ...             ⋱             ...          
  1.0000  1.0000  1.0000  ...   1.0000  1.0000  1.0000
  1.0000  1.0000  1.0000  ...   1.0000  1.0000  1.0000
  1.0000  1.0000  1.0000  ...   1.0000  1.0000  1.0000

(0 ,2 ,.,.) = 
  1.0000  1.0000  1.0000  ...   1.0000  1.0000  1.0000
  1.0000  1.0000  1.0000  ...   1.0000  1.0000  1.0000
  1.0000  1.0000  1.0000  ...   1.0000  1.0000  1.0000
   

(8, {'image': 
(0 ,0 ,.,.) = 
  1.0000  1.0000  1.0000  ...   1.0000  1.0000  1.0000
  1.0000  1.0000  1.0000  ...   1.0000  1.0000  1.0000
  1.0000  1.0000  1.0000  ...   1.0000  1.0000  1.0000
           ...             ⋱             ...          
  1.0000  1.0000  1.0000  ...   1.0000  1.0000  1.0000
  1.0000  1.0000  1.0000  ...   1.0000  1.0000  1.0000
  1.0000  1.0000  1.0000  ...   1.0000  1.0000  1.0000

(0 ,1 ,.,.) = 
  1.0000  1.0000  1.0000  ...   1.0000  1.0000  1.0000
  1.0000  1.0000  1.0000  ...   1.0000  1.0000  1.0000
  1.0000  1.0000  1.0000  ...   1.0000  1.0000  1.0000
           ...             ⋱             ...          
  1.0000  1.0000  1.0000  ...   1.0000  1.0000  1.0000
  1.0000  1.0000  1.0000  ...   1.0000  1.0000  1.0000
  1.0000  1.0000  1.0000  ...   1.0000  1.0000  1.0000

(0 ,2 ,.,.) = 
  1.0000  1.0000  1.0000  ...   1.0000  1.0000  1.0000
  1.0000  1.0000  1.0000  ...   1.0000  1.0000  1.0000
  1.0000  1.0000  1.0000  ...   1.0000  1.0000  1.0000
   

(10, {'image': 
(0 ,0 ,.,.) = 
  1.0000  1.0000  1.0000  ...   1.0000  1.0000  1.0000
  1.0000  1.0000  1.0000  ...   1.0000  1.0000  1.0000
  1.0000  1.0000  1.0000  ...   1.0000  1.0000  1.0000
           ...             ⋱             ...          
  1.0000  1.0000  1.0000  ...   1.0000  1.0000  1.0000
  1.0000  1.0000  1.0000  ...   1.0000  1.0000  1.0000
  1.0000  1.0000  1.0000  ...   1.0000  1.0000  1.0000

(0 ,1 ,.,.) = 
  1.0000  1.0000  1.0000  ...   1.0000  1.0000  1.0000
  1.0000  1.0000  1.0000  ...   1.0000  1.0000  1.0000
  1.0000  1.0000  1.0000  ...   1.0000  1.0000  1.0000
           ...             ⋱             ...          
  1.0000  1.0000  1.0000  ...   1.0000  1.0000  1.0000
  1.0000  1.0000  1.0000  ...   1.0000  1.0000  1.0000
  1.0000  1.0000  1.0000  ...   1.0000  1.0000  1.0000

(0 ,2 ,.,.) = 
  1.0000  1.0000  1.0000  ...   1.0000  1.0000  1.0000
  1.0000  1.0000  1.0000  ...   1.0000  1.0000  1.0000
  1.0000  1.0000  1.0000  ...   1.0000  1.0000  1.0000
  

(12, {'image': 
(0 ,0 ,.,.) = 
  1.0000  1.0000  1.0000  ...   1.0000  1.0000  1.0000
  1.0000  1.0000  1.0000  ...   1.0000  1.0000  1.0000
  1.0000  1.0000  1.0000  ...   1.0000  1.0000  1.0000
           ...             ⋱             ...          
  1.0000  1.0000  1.0000  ...   1.0000  1.0000  1.0000
  1.0000  1.0000  1.0000  ...   1.0000  1.0000  1.0000
  1.0000  1.0000  1.0000  ...   1.0000  1.0000  1.0000

(0 ,1 ,.,.) = 
  1.0000  1.0000  1.0000  ...   1.0000  1.0000  1.0000
  1.0000  1.0000  1.0000  ...   1.0000  1.0000  1.0000
  1.0000  1.0000  1.0000  ...   1.0000  1.0000  1.0000
           ...             ⋱             ...          
  1.0000  1.0000  1.0000  ...   1.0000  1.0000  1.0000
  1.0000  1.0000  1.0000  ...   1.0000  1.0000  1.0000
  1.0000  1.0000  1.0000  ...   1.0000  1.0000  1.0000

(0 ,2 ,.,.) = 
  1.0000  1.0000  1.0000  ...   1.0000  1.0000  1.0000
  1.0000  1.0000  1.0000  ...   1.0000  1.0000  1.0000
  1.0000  1.0000  1.0000  ...   1.0000  1.0000  1.0000
  

(14, {'image': 
(0 ,0 ,.,.) = 
  1.0000  1.0000  1.0000  ...   1.0000  1.0000  1.0000
  1.0000  1.0000  1.0000  ...   1.0000  1.0000  1.0000
  1.0000  1.0000  1.0000  ...   1.0000  1.0000  1.0000
           ...             ⋱             ...          
  1.0000  1.0000  1.0000  ...   1.0000  1.0000  1.0000
  1.0000  1.0000  1.0000  ...   1.0000  1.0000  1.0000
  1.0000  1.0000  1.0000  ...   1.0000  1.0000  1.0000

(0 ,1 ,.,.) = 
  1.0000  1.0000  1.0000  ...   1.0000  1.0000  1.0000
  1.0000  1.0000  1.0000  ...   1.0000  1.0000  1.0000
  1.0000  1.0000  1.0000  ...   1.0000  1.0000  1.0000
           ...             ⋱             ...          
  1.0000  1.0000  1.0000  ...   1.0000  1.0000  1.0000
  1.0000  1.0000  1.0000  ...   1.0000  1.0000  1.0000
  1.0000  1.0000  1.0000  ...   1.0000  1.0000  1.0000

(0 ,2 ,.,.) = 
  1.0000  1.0000  1.0000  ...   1.0000  1.0000  1.0000
  1.0000  1.0000  1.0000  ...   1.0000  1.0000  1.0000
  1.0000  1.0000  1.0000  ...   1.0000  1.0000  1.0000
  

(16, {'image': 
(0 ,0 ,.,.) = 
  0.9451  0.9490  0.9490  ...   0.9647  0.9647  0.9647
  0.9490  0.9490  0.9490  ...   0.9650  0.9657  0.9667
  0.9490  0.9507  0.9490  ...   0.9686  0.9706  0.9686
           ...             ⋱             ...          
  0.9014  0.8866  0.8928  ...   0.9373  0.9451  0.9425
  0.8948  0.8951  0.8797  ...   0.9353  0.9441  0.9435
  0.8942  0.8902  0.8751  ...   0.9366  0.9412  0.9451

(0 ,1 ,.,.) = 
  0.9412  0.9451  0.9451  ...   0.9686  0.9686  0.9686
  0.9451  0.9451  0.9451  ...   0.9650  0.9676  0.9667
  0.9451  0.9467  0.9451  ...   0.9647  0.9667  0.9647
           ...             ⋱             ...          
  0.9014  0.8905  0.8928  ...   0.9373  0.9431  0.9425
  0.8987  0.8951  0.8837  ...   0.9353  0.9441  0.9435
  0.8942  0.8941  0.8829  ...   0.9366  0.9412  0.9451

(0 ,2 ,.,.) = 
  0.9647  0.9686  0.9686  ...   0.9843  0.9843  0.9843
  0.9686  0.9686  0.9686  ...   0.9827  0.9843  0.9843
  0.9686  0.9703  0.9686  ...   0.9843  0.9863  0.9843
  

(18, {'image': 
(0 ,0 ,.,.) = 
  0.2980  0.2824  0.4559  ...   0.4196  0.4363  0.4480
  0.2529  0.2382  0.2461  ...   0.1324  0.1304  0.2549
  0.2471  0.2235  0.2520  ...   0.1216  0.1088  0.2353
           ...             ⋱             ...          
  0.7196  0.7167  0.7196  ...   0.6157  0.7069  0.6922
  0.7333  0.7255  0.7176  ...   0.6667  0.6598  0.6500
  0.7353  0.7402  0.4922  ...   0.6833  0.7069  0.6608

(0 ,1 ,.,.) = 
  0.3667  0.3549  0.4696  ...   0.4294  0.4363  0.4402
  0.2765  0.2745  0.2892  ...   0.1755  0.1892  0.2804
  0.2892  0.2892  0.3127  ...   0.1647  0.1804  0.2667
           ...             ⋱             ...          
  0.7353  0.7284  0.7333  ...   0.4843  0.1373  0.2824
  0.7500  0.7510  0.7373  ...   0.2431  0.1618  0.2559
  0.7549  0.7441  0.4696  ...   0.1167  0.1706  0.2206

(0 ,2 ,.,.) = 
  0.6245  0.6314  0.5657  ...   0.4627  0.4775  0.4735
  0.5373  0.5461  0.4824  ...   0.5471  0.5284  0.5382
  0.6137  0.5922  0.5167  ...   0.5490  0.5235  0.5765
  

(20, {'image': 
(0 ,0 ,.,.) = 
  1.0000  1.0000  1.0000  ...   1.0000  1.0000  1.0000
  1.0000  1.0000  1.0000  ...   1.0000  1.0000  1.0000
  1.0000  1.0000  1.0000  ...   1.0000  1.0000  1.0000
           ...             ⋱             ...          
  1.0000  1.0000  1.0000  ...   1.0000  1.0000  1.0000
  1.0000  1.0000  1.0000  ...   1.0000  1.0000  1.0000
  1.0000  1.0000  1.0000  ...   1.0000  1.0000  1.0000

(0 ,1 ,.,.) = 
  1.0000  1.0000  1.0000  ...   1.0000  1.0000  1.0000
  1.0000  1.0000  1.0000  ...   1.0000  1.0000  1.0000
  1.0000  1.0000  1.0000  ...   1.0000  1.0000  1.0000
           ...             ⋱             ...          
  1.0000  1.0000  1.0000  ...   1.0000  1.0000  1.0000
  1.0000  1.0000  1.0000  ...   1.0000  1.0000  1.0000
  1.0000  1.0000  1.0000  ...   1.0000  1.0000  1.0000

(0 ,2 ,.,.) = 
  1.0000  1.0000  1.0000  ...   1.0000  1.0000  1.0000
  1.0000  1.0000  1.0000  ...   1.0000  1.0000  1.0000
  1.0000  1.0000  1.0000  ...   1.0000  1.0000  1.0000
  

(22, {'image': 
(0 ,0 ,.,.) = 
  1.0000  1.0000  1.0000  ...   1.0000  1.0000  1.0000
  1.0000  1.0000  1.0000  ...   1.0000  1.0000  1.0000
  1.0000  1.0000  1.0000  ...   1.0000  1.0000  1.0000
           ...             ⋱             ...          
  1.0000  1.0000  1.0000  ...   1.0000  1.0000  1.0000
  1.0000  1.0000  1.0000  ...   1.0000  1.0000  1.0000
  1.0000  1.0000  1.0000  ...   1.0000  1.0000  1.0000

(0 ,1 ,.,.) = 
  1.0000  1.0000  1.0000  ...   1.0000  1.0000  1.0000
  1.0000  1.0000  1.0000  ...   1.0000  1.0000  1.0000
  1.0000  1.0000  1.0000  ...   1.0000  1.0000  1.0000
           ...             ⋱             ...          
  1.0000  1.0000  1.0000  ...   1.0000  1.0000  1.0000
  1.0000  1.0000  1.0000  ...   1.0000  1.0000  1.0000
  1.0000  1.0000  1.0000  ...   1.0000  1.0000  1.0000

(0 ,2 ,.,.) = 
  1.0000  1.0000  1.0000  ...   1.0000  1.0000  1.0000
  1.0000  1.0000  1.0000  ...   1.0000  1.0000  1.0000
  1.0000  1.0000  1.0000  ...   1.0000  1.0000  1.0000
  

(24, {'image': 
(0 ,0 ,.,.) = 
  0.1961  0.2069  0.1014  ...   0.0000  0.0000  0.0000
  0.2422  0.2441  0.2003  ...   0.0000  0.0000  0.0000
  0.1045  0.0539  0.0670  ...   0.0000  0.0000  0.0000
           ...             ⋱             ...          
  0.0000  0.0020  0.0039  ...   0.0000  0.0000  0.0000
  0.0000  0.0010  0.0020  ...   0.0000  0.0000  0.0000
  0.0000  0.0000  0.0000  ...   0.0000  0.0016  0.0033

(0 ,1 ,.,.) = 
  0.1922  0.2029  0.0975  ...   0.0000  0.0000  0.0000
  0.2382  0.2402  0.1964  ...   0.0000  0.0000  0.0000
  0.1005  0.0539  0.0537  ...   0.0000  0.0000  0.0000
           ...             ⋱             ...          
  0.0000  0.0020  0.0039  ...   0.0039  0.0039  0.0039
  0.0000  0.0010  0.0020  ...   0.0039  0.0039  0.0039
  0.0000  0.0000  0.0000  ...   0.0039  0.0039  0.0039

(0 ,2 ,.,.) = 
  0.2157  0.2265  0.1171  ...   0.0000  0.0000  0.0000
  0.2618  0.2618  0.2160  ...   0.0000  0.0000  0.0000
  0.1202  0.0618  0.0522  ...   0.0000  0.0000  0.0000
  

In [268]:
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F

In [2]:
class CNNModel(nn.Module):
    def __init__(self):
        super(CNNModel, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2,2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 200)
        self.fc2 = nn.Linear(200, 150)
        self.fc3 = nn.Linear(150, 128)
        
    def forward(self, x):
        x = self.pool(F.relu(self.conv1.double()(x)))
        x = self.pool(F.relu(self.conv2.double()(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1.double()(x))
        x = F.relu(self.fc2.double()(x))
        x = self.fc3.double()(x)
        return x

NameError: name 'nn' is not defined

In [215]:
conv1 = nn.Conv2d(3, 6, 5)
pool = nn.MaxPool2d(2,2)
conv2 = nn.Conv2d(6, 16, 5)
fc1 = nn.Linear(16 * 5 * 5, 120)
fc2 = nn.Linear(120, 84)
fc3 = nn.Linear(84, 10)

In [232]:
x = x.view(-1, 16 * 5 * 5)

In [246]:
x.view(-1, 3600)

Variable containing:
 0.0125  0.0125  0.0136  ...   0.0193  0.0632  0.0813
 0.0333  0.0166  0.0284  ...   0.0000  0.0000  0.0000
 0.0125  0.0125  0.0125  ...   0.0856  0.0847  0.0827
 0.0932  0.0765  0.0499  ...   0.0295  0.0000  0.0000
[torch.DoubleTensor of size 4x3600]

In [235]:
F.relu(fc1.double()(x)).shape

torch.Size([36, 120])

In [218]:
x = pool(F.relu(conv1.double()(x)))

In [221]:
x = pool(F.relu(conv2.double()(x)))

Variable containing:
 5
 5
 5
 5
[torch.LongTensor of size 4]

In [257]:
criterion(outputs.view(-1, 36), labels.view(4))

Variable containing:
 3.5640
[torch.DoubleTensor of size 1]

In [231]:
x.shape

torch.Size([4, 16, 15, 15])

In [230]:
x.view(-1).shape

torch.Size([14400])

In [225]:
x.shape

torch.Size([4, 16, 15, 15])

In [224]:
x.view(-1, 16 * 5 * 5).shape

torch.Size([36, 400])

In [222]:
x.shape

torch.Size([4, 16, 15, 15])

In [219]:
x.shape

torch.Size([4, 6, 35, 35])

In [213]:
x.shape

torch.Size([4, 3, 75, 75])

In [211]:
x = Variable(sample['image'])

In [270]:
cnn = CNNModel()

In [201]:
cnn.parameters()

<generator object Module.parameters at 0x7fe71a7a9468>

In [271]:
import torch.optim as optim

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(cnn.parameters(), lr=0.001, momentum=0.9)

In [272]:
for i, data in enumerate(dataloader, 0):
    images = Variable(data['image'])
    labels = Variable(data['label'].view(4))
    
    optimizer.zero_grad()
    
    outputs = cnn(images)
    loss = criterion(outputs.view(-1, 36), labels.view(4))
    loss.backward()
    optimizer.step()
    if i == 1:
        break 

AttributeError: 'CNNModel' object has no attribute 'fc4'

In [267]:
_, predicted = torch.max(outputs.data, 1)

In [264]:
predicted


 2
 2
 2
 2
 2
 2
 2
 2
 2
 2
 2
 2
 2
 2
 2
 2
 2
 2
 2
 2
 2
 2
 2
 2
 2
 2
 2
 2
 2
 2
 2
 2
 2
 2
 2
 2
[torch.LongTensor of size 36]

In [138]:
import pdb

In [203]:
dataiter = iter(dataloader)

In [204]:
sample = dataiter.next()

In [205]:
sample['image'].shape, sample['label'].shape

(torch.Size([4, 3, 75, 75]), torch.Size([4, 1]))

In [206]:
sample['label'].view(4)


 5
 5
 5
 5
[torch.LongTensor of size 4]

In [207]:
outputs = cnn(Variable(sample['image']))

In [208]:
outputs.shape

torch.Size([36, 4])

In [189]:
images = Variable(sample['label'].view(4))

In [192]:
criterion(outputs, images)

RuntimeError: Assertion `THIndexTensor_(size)(target, 0) == batch_size' failed.  at /opt/conda/conda-bld/pytorch_1512386481460/work/torch/lib/THNN/generic/ClassNLLCriterion.c:79

In [191]:
optimizer.zero_grad()

In [195]:
sample['label'].view(1)

RuntimeError: invalid argument 2: size '[1]' is invalid for input with 4 elements at /opt/conda/conda-bld/pytorch_1512386481460/work/torch/lib/TH/THStorage.c:41

In [198]:
outputs.shape

torch.Size([36, 10])