In [1]:
!gdown --id 1-5EJI4c44Ju7dSi77QPW6bV5O16hcwdW
!unzip archive.zip

Downloading...
From: https://drive.google.com/uc?id=1-5EJI4c44Ju7dSi77QPW6bV5O16hcwdW
To: /content/archive.zip
100% 436M/436M [00:02<00:00, 158MB/s]
Archive:  archive.zip
replace annotations/defect_free/1.txt? [y]es, [n]o, [A]ll, [N]one, [r]ename: A
  inflating: annotations/defect_free/1.txt  
  inflating: annotations/defect_free/10.txt  
  inflating: annotations/defect_free/11.txt  
  inflating: annotations/defect_free/12.txt  
  inflating: annotations/defect_free/13.txt  
  inflating: annotations/defect_free/14.txt  
  inflating: annotations/defect_free/15.txt  
  inflating: annotations/defect_free/16.txt  
  inflating: annotations/defect_free/17.txt  
  inflating: annotations/defect_free/18.txt  
  inflating: annotations/defect_free/19.txt  
  inflating: annotations/defect_free/2.txt  
  inflating: annotations/defect_free/20.txt  
  inflating: annotations/defect_free/21.txt  
  inflating: annotations/defect_free/22.txt  
  inflating: annotations/defect_free/23.txt  
  inflating: ann

In [2]:
!pip install pytorch-lightning
!pip install -q pytorch-metric-learning

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [3]:
from torch.utils.data import Dataset
import os
import numpy as np
from PIL import Image
from torchvision.transforms import Compose

class StainDataset(Dataset):
  """Two-class image dataset read from one sub-folder per class.

  Expected layout under ``path``::

      path/stain/<image files>
      path/defect_free/<image files>

  Indexing returns ``(image, label)``. ``label`` is the position of the class in
  ``class_names`` (0 for "stain", 1 for "defect free"). ``image`` is the decoded
  file as a NumPy array, or whatever ``transform`` returns when one is given.
  """

  n_classes = 2
  class_names = ["stain", "defect free"]
  # Human-readable class name -> folder name on disk.
  class_folder_names = {"stain": "stain", "defect free": "defect_free"}

  def __init__(self, path:str, transform:"Compose | None"=None):
    """Index every regular file found in the class folders.

    Args:
      path: dataset root containing one folder per class.
      transform: optional callable applied to the NumPy image in ``__getitem__``.

    Raises:
      AssertionError: if ``path`` or one of the class folders does not exist.
    """
    super().__init__()
    assert os.path.exists(path), f"dataset root not found: {path}"
    self.path = path
    self.transform = transform
    # List of (file path as str, integer label). Plain strings (rather than
    # os.DirEntry objects) keep the dataset picklable for DataLoader workers.
    self.images = []

    for label, class_name in enumerate(self.class_names):
      class_folder_path = os.path.join(self.path, self.class_folder_names[class_name])
      assert os.path.exists(class_folder_path), f"class folder not found: {class_folder_path}"
      # The context manager closes the directory handle; is_file() skips nested
      # folders (e.g. notebook checkpoint directories); sorting makes the sample
      # order independent of the filesystem's listing order.
      with os.scandir(class_folder_path) as entries:
        file_paths = sorted(entry.path for entry in entries if entry.is_file())
      self.images.extend((file_path, label) for file_path in file_paths)

  def __len__(self):
    """Number of indexed image files across all classes."""
    return len(self.images)

  def __getitem__(self, ndx):
    """Load sample ``ndx`` and return ``(image, label)``."""
    file_path, label = self.images[ndx]
    # Decode inside the context manager so the file handle is released right away.
    with Image.open(file_path) as opened:
      image = np.asarray(opened)
    if self.transform is not None:
      image = self.transform(image)
    return image, label

In [4]:
import pytorch_lightning as pl
from torch.nn import Module
from pytorch_metric_learning.miners import BatchHardMiner
from pytorch_metric_learning.losses import TripletMarginLoss
from torch.optim import Adam

class ModelLearner(pl.LightningModule):
    """LightningModule that trains an embedding network with a triplet margin loss.

    Triplets are selected from each batch by ``BatchHardMiner`` and scored with
    ``TripletMarginLoss``.

    Args:
        model: network mapping a batch of inputs to a batch of embeddings.
        margin: margin passed to ``TripletMarginLoss``.
        lr: learning rate of the Adam optimizer.

    Note:
        ``model`` is excluded from the saved hyperparameters, so pass it explicitly
        when restoring: ``ModelLearner.load_from_checkpoint(path, model=...)``.
    """

    def __init__(self, model:Module, margin=0.2, lr=1e-4):
        super().__init__()

        # The network's weights already live in the checkpoint's state_dict; storing
        # the module itself as a hyperparameter duplicates it (Lightning warns
        # about this), so only the scalar settings are recorded.
        self.save_hyperparameters(ignore=["model"])
        self.model = model
        self.margin = margin
        self.lr = lr

        self.miner = BatchHardMiner()
        self.loss = TripletMarginLoss(margin=self.margin)

    def forward(self, x):
        """Return the embeddings produced by the wrapped model."""
        return self.model(x)

    def _triplet_loss(self, batch):
        """Embed a ``(inputs, labels)`` batch, mine triplets and return the loss."""
        x, labels = batch
        embeddings = self.forward(x)
        triplets = self.miner(embeddings, labels)
        return self.loss(embeddings, labels, triplets)

    def training_step(self, batch, batch_idx):
        """Compute the loss for one training batch and log it per epoch."""
        loss = self._triplet_loss(batch)
        self.log('train_loss', loss, on_step=False, on_epoch=True, prog_bar=True, logger=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Compute the loss for one validation batch and log it per epoch."""
        loss = self._triplet_loss(batch)
        self.log('val_loss', loss, on_step=False, on_epoch=True, prog_bar=True, logger=True)

    def configure_optimizers(self):
        """Optimize all parameters with Adam at the configured learning rate."""
        return Adam(self.parameters(), lr=self.lr)

In [6]:
import torch
import torch.nn.functional as F
from torch import nn
from torch.utils.data import DataLoader
from torchvision.transforms import Compose, Grayscale, Normalize, Resize, ToPILImage, ToTensor

class MLPClassifier(nn.Module):
    """Two-layer perceptron applied to flattened inputs.

    Args:
        in_features: number of values per sample after flattening. The default,
            ``32*32``, corresponds to a single-channel 32x32 input.
        hidden_features: width of the hidden layer.
        out_features: size of the output vector produced for each sample.
    """

    def __init__(self, in_features=32*32, hidden_features=128, out_features=10):
        super().__init__()
        self.layer_1 = nn.Linear(in_features, hidden_features)
        self.layer_2 = nn.Linear(hidden_features, out_features)

    def forward(self, x):
        """Map a batch ``(N, ...)`` to ``(N, out_features)``."""
        # flatten() keeps the batch dimension and, unlike view(), also accepts
        # non-contiguous tensors (e.g. the result of a transpose/permute).
        x = x.flatten(start_dim=1)
        x = F.relu(self.layer_1(x))
        return self.layer_2(x)

IMAGE_SIZE = (32, 32)
BATCH_SIZE = 32

# MLPClassifier's first layer takes 32*32 values per sample, i.e. exactly one channel
# at IMAGE_SIZE. Grayscale guarantees a single channel whatever mode the image files
# are stored in; without it a multi-channel image would not fit the first layer.
transform = Compose([
    ToPILImage(),
    Grayscale(num_output_channels=1),
    Resize(IMAGE_SIZE),
    ToTensor(),
])

dataset = StainDataset(path="images", transform=transform)
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
model = ModelLearner(MLPClassifier())

  rank_zero_warn(


In [7]:
import pytorch_lightning as pl

# Training budget: upper bound on the number of passes over `dataloader`.
MAX_EPOCHS = 10

trainer = pl.Trainer(max_epochs=MAX_EPOCHS)
# Only a training loader is passed here; no validation loader is supplied.
trainer.fit(model, dataloader)

INFO:pytorch_lightning.utilities.rank_zero:GPU available: False, used: False
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs
  rank_zero_warn(
INFO:pytorch_lightning.callbacks.model_summary:
  | Name  | Type              | Params
--------------------------------------------
0 | model | MLPClassifier     | 132 K 
1 | miner | BatchHardMiner    | 0     
2 | loss  | TripletMarginLoss | 0     
--------------------------------------------
132 K     Trainable params
0         Non-trainable params
132 K     Total params
0.530     Total estimated model params size (MB)
  rank_zero_warn(


Training: 0it [00:00, ?it/s]

  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")
