In [None]:
# Mount Google Drive inside the Colab VM. Later cells read the dataset archive
# and label CSV from /content/drive/MyDrive and write checkpoints to Drive.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [51]:
# ---------------------------------------------------------------------------
# Notebook configuration: every switch the cells below read lives here.
# ---------------------------------------------------------------------------

# Shared objects; the load / train cells create them on demand.
dm = None
model = None

# Mini-batch size handed to the data module and the model.
batch_size = 64

# --- evaluate an existing checkpoint ---
Load_from_checkpoint = False
# The pieces are concatenated verbatim (the doubled slash after "MyDrive" is
# kept as-is so the resulting path string is unchanged).
chkpt_path = (
    'drive/MyDrive/'
    '/Checkpoints_sorted/color_recognition/'
    'simple_colors/'
    'epoch=5-step=287.ckpt'
)

# --- training ---
Train = True
epochs = 12
chkpts_upload_dir = 'drive/MyDrive/checkpoints/color_recognition/'

# --- visual inspection of test predictions ---
Show = False
show_examples = 30

In [39]:
# Stage the dataset in the working directory: copy the image archive and the
# label CSV from the mounted Drive, then unpack the images into setupdir/.
# `cp -u` skips the copy when an up-to-date local archive already exists.
!cp -u /content/drive/MyDrive/777x20_simple_colors.zip ./
!cp /content/drive/MyDrive/777x20_color.csv ./
# unzip flags: -q quiet, -n never overwrite existing files, -j discard the
# archive's directory structure so every file lands directly in setupdir/.
!unzip -q -n -j 777x20_simple_colors.zip -d setupdir/

In [None]:
# Install the deep-learning stack used by the rest of the notebook (-q: quiet).
# NOTE(review): the versions are unpinned, while the imports below rely on
# `pytorch_lightning.metrics`, which presumably exists only in older
# pytorch-lightning releases — confirm and consider pinning a known-good version.
!pip -q install torch
!pip -q install torchvision
!pip -q install pytorch-lightning

from torchvision import datasets, transforms, models

from pytorch_lightning import LightningDataModule, LightningModule, Trainer
from pytorch_lightning.metrics.functional import precision_recall
from pytorch_lightning.callbacks import ModelCheckpoint

import torch
from torch.nn import functional
from torch.utils.data import DataLoader, Dataset, random_split

import csv
import os
from PIL import Image

[K     |████████████████████████████████| 918 kB 12.0 MB/s 
[K     |████████████████████████████████| 829 kB 39.3 MB/s 
[K     |████████████████████████████████| 118 kB 48.9 MB/s 
[K     |████████████████████████████████| 636 kB 43.3 MB/s 
[K     |████████████████████████████████| 272 kB 52.0 MB/s 
[K     |████████████████████████████████| 1.3 MB 30.8 MB/s 
[K     |████████████████████████████████| 294 kB 46.2 MB/s 
[K     |████████████████████████████████| 142 kB 49.5 MB/s 
[?25h  Building wheel for future (setup.py) ... [?25l[?25hdone


In [None]:
class ColorDataset(Dataset):
    """Multi-label image dataset: every image is tagged with one or more colors.

    The label CSV has one row per image: the first field is the image file
    name, the remaining fields are the color names attached to that image.

    Attributes:
        color_list: distinct color names in order of first appearance in the
            CSV; a color's position is its index in the multi-hot target.
        color_by_filename: maps an image file name to its multi-hot list
            (one 0/1 entry per element of ``color_list``).
        labels: ``(image_path, multi_hot)`` pairs, one per file found in
            ``setupdir``, ordered by file name.
        transform: optional callable applied to each loaded image; ``None``
            by default and may be assigned after construction.
    """

    def __init__(self, setupdir, labels_csv):
        """Read the label CSV and index every image file in ``setupdir``.

        Args:
            setupdir: directory that directly contains the image files.
            labels_csv: path to the CSV described in the class docstring.

        Raises:
            KeyError: a file in ``setupdir`` has no row in ``labels_csv``.
        """
        self.color_list = []
        seen_colors = set()  # O(1) membership test; color_list keeps the order

        raw_color_by_filename = dict()
        # newline='' is what the csv module expects for files it parses.
        with open(labels_csv, 'r', newline='') as csvfile:
            reader = csv.reader(csvfile, quotechar='"', quoting=csv.QUOTE_ALL)
            for row in reader:
                if not row:
                    # Blank line (e.g. trailing newline): nothing to index.
                    continue
                raw_color_by_filename[row[0]] = row[1:]
                for color in row[1:]:
                    if color not in seen_colors:
                        seen_colors.add(color)
                        self.color_list.append(color)

        # Encode each image's color names as a multi-hot vector over color_list.
        self.color_by_filename = dict()
        for img, raw_colors in raw_color_by_filename.items():
            present = set(raw_colors)
            self.color_by_filename[img] = [int(color in present)
                                           for color in self.color_list]

        self.labels = []
        # os.listdir returns entries in arbitrary order; sorting makes sample
        # indices (and therefore any seeded split of this dataset) reproducible.
        for filename in sorted(os.listdir(setupdir)):
            if filename not in self.color_by_filename:
                raise KeyError(
                    f"image {filename!r} found in {setupdir!r} has no row "
                    f"in {labels_csv!r}")
            self.labels.append((os.path.join(setupdir, filename),
                                self.color_by_filename[filename]))

        # Was misspelled `tranform`, which left `self.transform` undefined and
        # made __getitem__ fail unless a caller assigned it from outside.
        self.transform = None

    def __len__(self):
        """Return the number of indexed images."""
        return len(self.labels)

    def __getitem__(self, ind):
        """Load sample ``ind``.

        Returns:
            ``(image, target)`` where ``image`` is the loaded image after the
            optional ``transform`` and ``target`` is a float tensor holding
            the multi-hot color vector.
        """
        path, multi_hot = self.labels[ind]
        img = datasets.folder.default_loader(path)
        if self.transform is not None:
            img = self.transform(img)

        return (img, torch.tensor(multi_hot, dtype=torch.float32))

In [None]:
class RecognizeColorDM(LightningDataModule):
    """Data module that serves train / val / test splits of a ColorDataset.

    Args:
        setupdir: directory containing the image files.
        labels_csv: CSV mapping image file names to color names.
        train_frac: fraction of the samples used for training; half of the
            remainder (rounded down) goes to validation, the rest to test.
        seed: value passed to ``torch.manual_seed`` before splitting.
        batch_size: batch size used by all three data loaders.

    After ``setup()`` the module exposes ``train``, ``val`` and ``test``
    (subsets of one shared dataset) and ``num_colors``.
    """

    def __init__(self, setupdir, labels_csv, train_frac=0.9, seed=0, batch_size=64):

        super().__init__()

        self.setupdir = setupdir
        self.labels_csv = labels_csv
        self.train_frac = train_frac
        self.seed = seed
        self.batch_size = batch_size

        # for testing purposes only
        self.paths = []
        self.test_mode = False

        # Resize + center-crop to 224x224 and normalize with the channel
        # mean/std conventionally used for torchvision's pretrained models.
        self.transform = transforms.Compose([
              transforms.Resize(size=256),
              transforms.CenterCrop(size=224),
              transforms.ToTensor(),
              transforms.Normalize([0.485, 0.456, 0.406],
                                   [0.229, 0.224, 0.225])
        ])

    def setup(self, stage=None):
        """Build the dataset and split it into train / val / test subsets.

        Args:
            stage: accepted because Lightning passes it when it invokes this
                hook itself; unused, so a plain ``dm.setup()`` keeps working.
        """
        # Seeds the global RNG so the split below is repeatable. Kept global
        # (rather than a local generator) because it is an existing side
        # effect that later random initialisation may depend on.
        torch.manual_seed(self.seed)

        dataset = ColorDataset(self.setupdir, self.labels_csv)
        # All three subsets wrap this one dataset object, so a single
        # assignment gives every split the same transform.
        dataset.transform = self.transform

        self.num_colors = len(dataset.color_list)

        set_len = len(dataset)
        train_len = int(set_len * self.train_frac)
        val_len = int(set_len * (1 - self.train_frac) / 2)
        # The test split absorbs whatever the two int() roundings left over.
        test_len = set_len - train_len - val_len

        self.train, self.val, self.test = random_split(
            dataset, [train_len, val_len, test_len])

    def train_dataloader(self):
        """Shuffled loader over the training subset."""
        return DataLoader(self.train, batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self):
        """Sequential loader over the validation subset."""
        return DataLoader(self.val, batch_size=self.batch_size)

    def test_dataloader(self):
        """Sequential loader over the test subset."""
        return DataLoader(self.test, batch_size=self.batch_size)

In [None]:
class RecognizeColorModel(LightningModule):
    """Multi-label color classifier: ResNet-34 followed by one linear layer.

    Args:
        input_shape: shape of a single input image, e.g. ``(3, 224, 224)``;
            used only to probe the feature extractor's output size.
        num_classes: number of colors, i.e. length of the output vector.
        learning_rate: learning rate of the Adam optimizer.
        batch_size: stored on the module and in the saved hyperparameters.

    ``forward`` returns per-class sigmoid probabilities. ``predictions`` holds
    ``(probabilities, target)`` pairs for every sample of the latest test run.
    """

    def __init__(self, input_shape, num_classes,
                 learning_rate = 1e-4, batch_size=64):

        super().__init__()

        self.batch_size = batch_size

        self.save_hyperparameters()
        self.learning_rate = learning_rate
        self.dim = input_shape
        self.num_classes = num_classes

        self.feature_extractor = models.resnet34(pretrained=True)
        # NOTE(review): eval() only switches the mode; it does not freeze the
        # weights, and the training loop presumably puts the whole module back
        # into train mode — confirm whether freezing was intended.
        self.feature_extractor.eval()

        n_sizes = self._get_conv_output(input_shape)
        self.classifier = torch.nn.Linear(n_sizes, num_classes)

        self.predictions = []

    def _get_conv_output(self, shape):
        """Return the flattened feature size for one input of ``shape``."""
        batch_size = 1
        probe = torch.rand(batch_size, *shape)

        # Only the output shape is needed, so skip building an autograd graph.
        with torch.no_grad():
            features = self._forward_features(probe)
        return features.view(batch_size, -1).size(1)

    def _forward_features(self, x):
        """Run the backbone on ``x``."""
        return self.feature_extractor(x)

    def forward(self, x):
        """Return per-class probabilities for a batch of images."""
        x = self._forward_features(x)
        x = x.view(x.size(0), -1)
        # torch.sigmoid replaces the deprecated functional.sigmoid.
        x = torch.sigmoid(self.classifier(x))

        return x

    def _shared_step(self, batch):
        """Forward pass plus loss and precision/recall for one batch.

        Returns:
            ``(sigmoids, y, loss, precision, recall)``.
        """
        x, y = batch
        sigmoids = self(x)
        loss = functional.binary_cross_entropy(sigmoids, y)
        prec_rec = precision_recall(sigmoids, y.int())
        return sigmoids, y, loss, prec_rec[0], prec_rec[1]

    def training_step(self, batch, batch_idx):
        """Compute and log the training loss, precision and recall."""
        _, _, loss, prec, rec = self._shared_step(batch)

        self.log('train_loss', loss, on_step=True, on_epoch=True, logger=True)
        self.log('train_prec', prec, on_step=True, on_epoch=True, logger=True)
        self.log('train_rec', rec, on_step=True, on_epoch=True, logger=True)

        return loss

    def validation_step(self, batch, batch_idx):
        """Compute and log the validation loss, precision and recall."""
        _, _, loss, prec, rec = self._shared_step(batch)

        self.log('val_loss', loss, prog_bar=True)
        self.log('val_prec', prec, prog_bar=True)
        self.log('val_rec', rec, prog_bar=True)
        return loss

    def on_test_epoch_start(self):
        """Drop predictions from any earlier test run.

        Consumers pair ``predictions[i]`` with the i-th test sample, which
        only holds when the list covers exactly one pass over the test set.
        """
        self.predictions = []

    def test_step(self, batch, batch_idx):
        """Log test metrics and record per-sample predictions."""
        sigmoids, y, loss, prec, rec = self._shared_step(batch)

        # Keep detached CPU copies so the list does not hold on to device
        # memory; one (probabilities, target) pair per sample, in batch order.
        self.predictions.extend(zip(sigmoids.detach().cpu(), y.detach().cpu()))

        self.log('test_loss', loss, prog_bar=True)
        self.log('test_prec', prec, prog_bar=True)
        self.log('test_rec', rec, prog_bar=True)
        return loss

    def configure_optimizers(self):
        """Adam over all module parameters with the configured learning rate."""
        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        return optimizer

In [52]:
# load model from a checkpoint
# Runs only when Load_from_checkpoint is enabled in the config cell. It builds
# a fresh data module, restores the model from chkpt_path and evaluates it on
# the test split (test_step also records per-sample predictions on the model).

if Load_from_checkpoint:
    dm = RecognizeColorDM(setupdir='setupdir', labels_csv='777x20_color.csv',
                          train_frac=0.7, seed=0, batch_size=batch_size)
    dm.setup()
    # NOTE(review): test_mode is toggled around the test run, but nothing in
    # the visible notebook reads it — confirm it is still needed.
    dm.test_mode = True
    # gpus=[0]: run on the first GPU; requires a GPU runtime.
    trainer = Trainer(gpus=[0])
    model = RecognizeColorModel.load_from_checkpoint(chkpt_path)
    trainer.test(model, dm)
    dm.test_mode = False

In [40]:
# train model
# Runs only when Train is enabled in the config cell. A data module / model
# created by the checkpoint-loading cell is reused; otherwise both are built here.
if Train:
    if not dm:
        dm = RecognizeColorDM(setupdir='setupdir',
                              labels_csv='777x20_color.csv',
                              train_frac=0.7, seed=0,
                              batch_size=batch_size)
        dm.setup()

    if not model:
        # dm.num_colors is set by dm.setup(), so the data module must be ready first.
        model = RecognizeColorModel((3,224,224), dm.num_colors,
                              batch_size=batch_size, learning_rate=2e-4)


    # Two independent checkpoint trackers: one keeps the single best checkpoint
    # by 'val_loss', the other the single best by 'val_rec' (mode='max') in a
    # separate directory.
    checkpoint_loss = ModelCheckpoint(dirpath=chkpts_upload_dir,
                                monitor='val_loss', save_top_k=1)
    checkpoint_recall = ModelCheckpoint(dirpath=chkpts_upload_dir+'by_recall',
                                monitor='val_rec', save_top_k=1, mode='max')

    # gpus=[0]: train on the first GPU; requires a GPU runtime.
    trainer = Trainer(max_epochs=epochs,
                    progress_bar_refresh_rate=1,
                    gpus=[0],
                    callbacks = [checkpoint_loss,
                                 checkpoint_recall])

    trainer.fit(model, dm)

    # NOTE(review): called without arguments — presumably evaluates the model
    # and data module attached by fit(); which weights are used may depend on
    # the installed Lightning version — confirm.
    trainer.test()

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
  f"DataModule.{name} has already been called, so it will not be called again. "
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name              | Type   | Params
---------------------------------------------
0 | feature_extractor | ResNet | 21.8 M
1 | classifier        | Linear | 887 K 
---------------------------------------------
22.7 M    Trainable params
0         Non-trainable params
22.7 M    Total params
90.742    Total estimated model params size (MB)


Validation sanity check: 0it [00:00, ?it/s]

  f"The number of training samples ({self.num_training_batches}) is smaller than the logging interval"


Training: -1it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")
  f"DataModule.{name} has already been called, so it will not be called again. "
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Testing: 0it [00:00, ?it/s]

KeyboardInterrupt: ignored

In [53]:
if Show:
    import matplotlib.pyplot as plt
    %matplotlib inline
    from PIL import Image
    import numpy as np

    color_list = np.array(dm.test.dataset.color_list)
    print("Right Ones")
    i, count = 0, 0
    while i < len(model.predictions) and count < show_examples:
        predict = color_list[model.predictions[i][0].detach().cpu().numpy() > 0.5]
        true_label = color_list[model.predictions[i][1].detach().cpu().numpy() == 1]
        if set(predict) == set(true_label):
            count += 1
            print('predict', predict)
            print('true label', true_label)
            ind = dm.test.indices[i]
            plt.imshow(Image.open(dm.test.dataset.labels[ind][0]))
            plt.show()
            print('='*80)
        i += 1

In [54]:
if Show:
    # Display up to `show_examples` test images whose predicted color set
    # (probability > 0.5) differs from the labelled color set.
    # Relies on color_list, plt and Image created by the previous cell.
    print("Wrong Ones")
    shown = 0
    for pos, (probs, target) in enumerate(model.predictions):
        if shown >= show_examples:
            break

        predict = color_list[probs.detach().cpu().numpy() > 0.5]
        true_label = color_list[target.detach().cpu().numpy() == 1]
        if set(predict) == set(true_label):
            continue

        shown += 1
        print('predict', predict)
        print('true label', true_label)
        # predictions[pos] is paired with the pos-th sample of the test
        # subset; its `indices` map back into the full dataset's label list.
        img_path = dm.test.dataset.labels[dm.test.indices[pos]][0]
        plt.imshow(Image.open(img_path))
        plt.show()
        print('='*80)