In [59]:
from fastai.vision import *
from torch.utils.data import Dataset, DataLoader, RandomSampler

from fastai.core import ItemBase
from fastai.basic_data import DataBunch

from sklearn.model_selection import train_test_split

import pandas as pd
import numpy as np

In [3]:
### Read in a dataframe that contains all images, classes, and their dimensions, created earlier ###
tr = pd.read_csv('../exploratory_analysis/image_dims.csv')

We don't want to include the `new_whale` label in training, because we really can't establish if there happen to be any pairs existing in the dataset already.

In [4]:
### Get counts of the various labels, which returns in descending count order ###
wcts = tr['Id'].value_counts()

### Grab everything except for the most numerous, which is the new_whale label ###
nonew = pd.DataFrame(wcts[1:]).reset_index()
nonew.columns = ['Id', 'cts']

In [5]:
### We extract the list of images that can be used as same class and those that cannnot ####
same = nonew[nonew['cts']>1]['Id'].values
diff = nonew[nonew['cts']==1]['Id'].values

In [38]:
class SiameseDataset(Dataset):
    def __init__(self, data_path, df, same, diff):
        self.c = 1
        self.data_path = data_path
        self.df = df
        self.same_groups = same
        self.diff_groups = diff
        
    def __len__(self):
        return len(diff)
        
    def __getitem__(self, index):
        ### We will need to draw pairs of similar images as often as dissimilar images ###
        ### This will also function as the label we pass ###
        similar = np.random.choice([0 , 1])
        
        ### To draw two dissimilar items ###
        if similar==0:
            ### Draw items from two different objects ###
            ind1, ind2 = np.random.choice(self.diff_groups, size = 2, replace=False)
            
            img1 = np.random.choice(self.df[self.df['Id']==ind1]['Image'].values)
            img2 = np.random.choice(self.df[self.df['Id']==ind2]['Image'].values)

        elif similar==1:
            ind = np.random.choice(self.same_groups)
            
            img1, img2 = np.random.choice(self.df[self.df['Id']==ind]['Image'].values, size=2, replace=False)
        
        img1 = open_image(self.data_path + img1)
        img2 = open_image(self.data_path + img2)
        
        return img1, img2, similar

In [39]:
class SiameseDataLoader(DataLoader):
    def __init__(self, data_path, df, same, diff, **kwargs):
        super().__init__(self, **kwargs)
        self.same = same
        self.diff = diff
        self.df = df

In [None]:
class SiameseDataBunch(DataBunch):
    

In [51]:
# sd = SiameseDataset('../../train/', tr, same, diff)

In [41]:
# img1, img2, pair_class = sd[0]

# img1.show(); img2.show(); print(pair_class)

In [54]:
same_train, same_val = train_test_split(same, test_size=0.2)
diff_train, diff_val = train_test_split(diff, test_size=0.2)

In [56]:
train_dataset = SiameseDataset('../../train/', tr, same_train, diff_train)
val_dataset = SiameseDataset('../../train/', tr, same_val, diff_val)

train_dataloader = DataLoader(train_dataset)
val_dataloader = DataLoader(val_dataset)

## To Do: 
Will need to define the ability to do transforms, and how to actually use the batch size, etc.

In [None]:
ImageDataBunch()

In [47]:
train_sampler = RandomSampler(sd)
sdl = SiameseDataLoader('../../train/', tr, same, diff, sampler=train_sampler)

In [48]:
len(sdl)

2073

In [50]:
next(sdl)

TypeError: 'SiameseDataLoader' object is not an iterator

In [14]:
tr.head()

Unnamed: 0,Image,Id,x,y,channels
0,0000e88ab.jpg,w_f48451c,700,1050,3
1,0001f9222.jpg,w_c3d896a,325,758,3
2,00029d126.jpg,w_20df2c5,497,1050,3
3,00050a15a.jpg,new_whale,525,1050,3
4,0005c1ef8.jpg,new_whale,525,1050,3


In [35]:
np.random.choice(tr[tr['Id']=='w_f48451c']['Image'].values)

'e2f1b6c4a.jpg'

In [37]:
tst1, tst2 = np.random.choice(tr[tr['Id']=='w_f48451c']['Image'].values, size=2)

In [39]:
tst1, tst2

('0000e88ab.jpg', '9fc84d2ae.jpg')

In [58]:
class ImageTuple(ItemBase):
    def __init__(self, img1, img2, similarity_label):
        self.img1,self.img2 = img1,img2
        self.obj,self.data = (img1,img2),[similarity_label]
        
    def apply_tfms(self, tfms, **kwargs):
        self.img1 = self.img1.apply_tfms(tfms, **kwargs)
        self.img2 = self.img2.apply_tfms(tfms, **kwargs)
        return self

In [None]:
class ImageTupleList(ImageItemList):
    def __init__(self, items, itemsB=None, **kwargs):
        self.itemsB = itemsB
        super().__init__(items, **kwargs)
    
    def new(self, items, **kwargs):
        return super().new(items, itemsB=self.itemsB, **kwargs)
    
    def get(self, i):
        img1 = super().get(i)
        fn = self.itemsB[random.randint(0, len(self.itemsB)-1)]
        return ImageTuple(img1, open_image(fn))
    
    @classmethod
    def from_folders(cls, path, folderA, folderB, **kwargs):
        itemsB = ImageItemList.from_folder(path/folderB).items
        res = super().from_folder(path/folderA, itemsB=itemsB, **kwargs)
        res.path = path
        return res

    def reconstruct(self, t:Tensor): 
        return ImageTuple(Image(t[0]/2+0.5),Image(t[1]/2+0.5))

class TargetTupleList(ItemList):
    def reconstruct(self, t:Tensor): 
        if len(t.size()) == 0: return t
        return ImageTuple(Image(t[0]/2+0.5),Image(t[1]/2+0.5))

class ImageTupleList(ImageItemList):
    _label_cls=TargetTupleList
    def __init__(self, items, itemsB=None, **kwargs):
        self.itemsB = itemsB
        super().__init__(items, **kwargs)
    
    def new(self, items, **kwargs):
        return super().new(items, itemsB=self.itemsB, **kwargs)
    
    def get(self, i):
        img1 = super().get(i)
        fn = self.itemsB[random.randint(0, len(self.itemsB)-1)]
        return ImageTuple(img1, open_image(fn))
    
    def reconstruct(self, t:Tensor): 
        return ImageTuple(Image(t[0]/2+0.5),Image(t[1]/2+0.5))
    
    @classmethod
    def from_folders(cls, path, folderA, folderB, **kwargs):
        itemsB = ImageItemList.from_folder(path/folderB).items
        res = super().from_folder(path/folderA, itemsB=itemsB, **kwargs)
        res.path = path
        return res

    def show_xys(self, xs, ys, figsize:Tuple[int,int]=(12,6), **kwargs):
        "Show the `xs` and `ys` on a figure of `figsize`. `kwargs` are passed to the show method."
        rows = int(math.sqrt(len(xs)))
        fig, axs = plt.subplots(rows,rows,figsize=figsize)
        for i, ax in enumerate(axs.flatten() if rows > 1 else [axs]):
            xs[i].to_one().show(ax=ax, **kwargs)
        plt.tight_layout()

    def show_xyzs(self, xs, ys, zs, figsize:Tuple[int,int]=None, **kwargs):
        """Show `xs` (inputs), `ys` (targets) and `zs` (predictions) on a figure of `figsize`.
        `kwargs` are passed to the show method."""
        figsize = ifnone(figsize, (12,3*len(xs)))
        fig,axs = plt.subplots(len(xs), 2, figsize=figsize)
        fig.suptitle('Ground truth / Predictions', weight='bold', size=14)
        for i,(x,z) in enumerate(zip(xs,zs)):
            x.to_one().show(ax=axs[i,0], **kwargs)
            z.to_one().show(ax=axs[i,1], **kwargs)