In [1]:
from fastai.vision.all import *
import os

# Later cells call `.cuda()` on the model and on every input, so stop early
# with a clear message when no CUDA device is visible.
if not torch.cuda.is_available():
    raise ValueError("No CUDA device found!")

plt.style.use('dark_background')

# `path` is whatever `untar_data` returns for this archive; it is used as the
# directory that holds the `subject*` image files below.
path = untar_data("http://vision.ucsd.edu/datasets/yale_face_dataset_original/yalefaces.zip")

# redundant file
redundant_file = path/"subject01.glasses.gif"
if redundant_file.exists():
    redundant_file.unlink()

# incorrect naming
misnamed_file = path/"subject01.gif"
if misnamed_file.exists():
    misnamed_file.rename(path/"subject01.centerlight")

files = L(path.glob("subject*"))

In [2]:
class SiameseImage(fastuple):
    """A `(first image, second image, same flag)` triple that can be drawn as a single picture."""
    def show(self, ctx=None, **kwargs):
        """Draw both images side by side, separated by a blank strip, titled with the flag.

        ctx: forwarded to `show_image`.
        kwargs: accepted but not used.
        Returns whatever `show_image` returns.
        """
        first, second, same = self
        if isinstance(first, Tensor):
            # Tensors are used as they are (assumed to be channel-first already -- TODO confirm).
            left, right = first, second
        else:
            # Non-tensor images: make the sizes agree, then move the last axis to the front.
            if second.size != first.size: second = second.resize(first.size)
            left, right = (tensor(im).permute(2,0,1) for im in (first, second))
        # Zero-filled strip, 10 wide along the last axis, placed between the two images.
        separator = left.new_zeros(left.shape[0], left.shape[1], 10)
        return show_image(torch.cat([left, separator, right], dim=2), title=same, ctx=ctx)
    
class SiameseTransform(Transform):
    """Turn a single file into a `SiameseImage` by pairing it with a second file.

    The second file is drawn at random: with probability 0.5 from the same
    label, otherwise from a randomly chosen different label. Files in the
    validation split get their partner drawn once, in the constructor, so the
    validation pairs do not change between calls.
    """
    def __init__(self, files, label_func, splits):
        """
        files: `L` of files to draw partners from.
        label_func: callable mapping a file to its (hashable) label.
        splits: index collections; `splits[1]` selects the validation files.
        """
        self.label_func = label_func
        self.labels = files.map(label_func).unique()
        # Group the files by label in one pass (one `label_func` call per file)
        # instead of re-scanning every file for every label.
        self.lbl2files = {l: L() for l in self.labels}
        for f in files: self.lbl2files[label_func(f)].append(f)
        self.valid = {f: self._draw(f) for f in files[splits[1]]}

    def encodes(self, f):
        """Open `f` and its partner and return them with the same/different flag."""
        # Only draw a new partner when `f` has no fixed validation pair.
        # `dict.get(f, self._draw(f))` would evaluate `_draw` on every call,
        # even for validation files, wasting work and advancing the RNG.
        pair = self.valid.get(f)
        if pair is None: pair = self._draw(f)
        f2, t = pair
        img1, img2 = PILImage.create(f), PILImage.create(f2)
        return SiameseImage(img1, img2, t)

    def _draw(self, f):
        """Return `(partner_file, same)` for `f`.

        NOTE: when `same` is True the partner is chosen among all files of the
        label, so it can be `f` itself. When `same` is False there must be at
        least one other label, otherwise `random.choice` has nothing to pick.
        """
        same = random.random() < 0.5
        cls = self.label_func(f)
        if not same:
            cls = random.choice(L(l for l in self.labels if l != cls))
        return random.choice(self.lbl2files[cls]), same
              
class SiameseModel(Module):
    """Score a pair of inputs with a shared encoder.

    Both inputs go through the same `encoder`; the element-wise absolute
    difference of the two encodings is mapped to a single value by a linear
    layer and passed through a sigmoid, so the output lies in (0, 1).
    """
    # NOTE(review): there is no explicit `super().__init__()` call, exactly as before.
    # `Module` is presumably fastai's base class, which takes care of it -- confirm
    # before switching the base class to `nn.Module`.
    def __init__(self, encoder, n_features=1024):
        """
        encoder: module applied to each input; its output must have `n_features`
            values in the last dimension.
        n_features: input width of the final linear layer. The default of 1024
            keeps the size that used to be hard-coded.
        """
        self.encoder = encoder
        self.fc = nn.Linear(n_features, 1)

    def forward(self, x1, x2):
        """Return the sigmoid score for the pair `(x1, x2)`."""
        e1 = self.encoder(x1)
        e2 = self.encoder(x2)
        # torch.sigmoid avoids building a new nn.Sigmoid module on every call.
        return torch.sigmoid(self.fc(torch.abs(e1 - e2)))

In [3]:
def my_loss(out, target):
    """Binary cross-entropy between predicted probabilities and targets.

    out: tensor of probabilities with a size-1 dimension at index 1, which is
        squeezed away before the comparison.
    target: tensor of targets; converted to float.
    Returns the `nn.BCELoss` value with its default reduction.
    """
    probs = torch.squeeze(out, 1).float()
    labels = target.float()
    # nn.BCELoss refuses to run inside a CUDA autocast region, and the learner
    # in this notebook is built with `.to_fp16()`. Compute the loss in float32
    # with autocast switched off locally so it works with and without mixed
    # precision. (The longer-term alternative is to return logits from the model
    # and use nn.BCEWithLogitsLoss, which would change the model's output.)
    with torch.autocast(device_type=probs.device.type, enabled=False):
        return nn.BCELoss()(probs, labels)

def my_accuracy(input, target):
    """Fraction of predictions that agree with `target`.

    input: tensor of scores with a size-1 dimension at index 1; a score
        strictly above 0.5 counts as a positive prediction.
    target: tensor compared element-wise with the thresholded scores.
    Returns the mean of the matches as a float tensor.
    """
    predicted_same = torch.squeeze(input, 1) > 0.5
    hits = predicted_same == target
    return hits.float().mean()

def label_func(fname):
    """Return the subject id from a file name, e.g. `subject01.glasses` -> `'01'`.

    fname: path-like object exposing the file name as `.name`.
    Raises ValueError when the name does not start with `subject<id>.`.
    """
    # `[^.]+` stops at the first dot. A greedy `.*` runs to the LAST dot, so a
    # name such as `subject01.glasses.gif` would yield `01.glasses`.
    found = re.match(r'^subject([^.]+)\.', fname.name)
    if found is None:
        raise ValueError(f"cannot extract a subject label from {fname.name!r}")
    return found.group(1)

# make a test set
# our test will have faces with glasses, these won't be in the training
SEED = 42
# Seed the RNGs first: the train/validation split and the validation pairs that
# SiameseTransform draws in its constructor are both random, and a later cell
# rebuilds a Learner on these DataLoaders to load previously saved weights.
set_seed(SEED)

# glob does not promise an order, so sort to make the seeded split repeatable.
ordered_files = sorted(files)
test_files = L(f for f in ordered_files if ".glasses" in f.name)
train_files = L(f for f in ordered_files if ".glasses" not in f.name)

splits = RandomSplitter(seed=SEED)(train_files)
tfm = SiameseTransform(train_files, label_func, splits)
tls = TfmdLists(train_files, tfm, splits=splits)
dls = tls.dataloaders(after_item=[ToTensor], after_batch=[IntToFloatTensor, Normalize.from_stats(*imagenet_stats)], bs=16)

# The encoder output must match the 1024 inputs of SiameseModel's linear layer
# (presumably 512 resnet18 channels, doubled by the concat pooling -- TODO confirm).
encoder = nn.Sequential(
    create_body(resnet18, cut=-2),
    AdaptiveConcatPool2d(),
    nn.Flatten()
)

model = SiameseModel(encoder)

In [None]:
# Train the siamese model, save the weights as "yale_face", then plot the recorded losses.
learn = Learner(dls, model, loss_func=my_loss, metrics=my_accuracy)
learn = learn.to_fp16()
#learn.lr_find()
learn.fit(50, lr=1e-3)  # 50 epochs at a fixed learning rate of 1e-3
learn.save("yale_face")
learn.recorder.plot_loss()

In [4]:
# Rebuild the learner around the same model and DataLoaders, move the model to
# the GPU and restore the weights saved under "yale_face".
learn = Learner(dls, model, loss_func=my_loss, metrics=my_accuracy)
learn = learn.to_fp16()
learn.model.cuda()
learn.load("yale_face")

<fastai.learner.Learner at 0x7f1ae94ebf90>

In [10]:
# For every test image we compare against all the training images.
# Each training image votes with the similarity score the network gives the pair.
# This is sort of like K nearest neighbour, where K is the number of training images.
# The pairs are still scored one at a time in a nested loop to keep it simple,
# but every image is loaded and preprocessed only once.

# Score in eval mode. A freshly built model is in training mode, where layers
# such as BatchNorm/Dropout (presumably present in the resnet18 body -- confirm)
# behave differently and BatchNorm keeps updating its running statistics.
learn.model.eval()

# Build the preprocessing transforms once instead of once per image.
item_tfm = ToTensor()
float_tfm = IntToFloatTensor()
norm_tfm = Normalize.from_stats(*imagenet_stats)

def _load_input(fname):
    """Open `fname` and return it preprocessed on the GPU, ready for `learn.model`."""
    # disable gradients to save GPU memory!
    with torch.no_grad():
        x = item_tfm(PILImage.create(fname)).cuda()
        x = float_tfm(x)
        return norm_tfm(x)

# (label, input) for every training image, computed once up front.
# NOTE: this keeps all training inputs on the GPU at the same time.
train_inputs = [(label_func(f), _load_input(f)) for f in train_files]

correct = 0

for f1 in sorted(test_files):
    x1 = _load_input(f1)

    vote = {}

    with torch.no_grad():
        for label, x2 in train_inputs:
            # .item() needs the model to return exactly one value per pair, which
            # comparing the accumulated tensors in max() required as well.
            vote[label] = vote.get(label, 0.0) + learn.model(x1, x2).item()

    best_label = max(vote, key=vote.get)

    if label_func(f1) == best_label:
        correct += 1

    print(f"{f1.name} most similar to label {best_label}")

print(f"correct classification {correct}/{len(test_files)}")

subject01.glasses most similar to label 01
subject02.glasses most similar to label 02


KeyboardInterrupt: 