In [1]:
import os
import math
import cv2
import pandas as pd
from PIL import Image
import torch
import torch.nn as nn
from torch.nn import functional as F
from torch import linalg
import torch.utils.data as data
import open_clip
import pytorch_lightning as pl

def read_image(image_file):
    """Load an image file and return it as an RGB numpy array.

    Args:
        image_file: path of the image to read.

    Returns:
        The decoded image, converted from OpenCV's BGR channel order to RGB.

    Raises:
        ValueError: if OpenCV cannot read or decode the file.
    """
    # IMREAD_IGNORE_ORIENTATION: keep pixel layout as stored, ignoring EXIF rotation.
    img = cv2.imread(
        image_file, cv2.IMREAD_COLOR | cv2.IMREAD_IGNORE_ORIENTATION
    )
    # cv2.imread signals failure by returning None. Check *before* cvtColor,
    # which would otherwise raise an opaque cv2.error on the None input.
    if img is None:
        raise ValueError('Failed to read {}'.format(image_file))
    return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)


class Product10KDataset(data.Dataset):
    """Image dataset backed by an annotation CSV (Products-10K layout).

    The first CSV column is the image path relative to ``root``; in training
    mode the second column is the class target.

    Args:
        root: directory that the CSV image paths are relative to.
        annotation_file: CSV file, loaded with ``pandas.read_csv``.
        transforms: callable applied to the ``PIL.Image`` before returning it.
        is_inference: if True, ``__getitem__`` returns only the image.
        with_bbox: if True, crop each image to the ``bbox_x``, ``bbox_y``,
            ``bbox_w``, ``bbox_h`` columns of its row before transforming.
    """

    def __init__(self, root, annotation_file, transforms, is_inference=False,
                 with_bbox=False):
        self.root = root
        self.imlist = pd.read_csv(annotation_file)
        self.transforms = transforms
        self.is_inference = is_inference
        self.with_bbox = with_bbox

    def __getitem__(self, index):
        """Return ``img`` (inference) or ``(img, target)`` for row ``index``."""
        # Caps OpenCV's internal thread pool in the calling (worker) process.
        cv2.setNumThreads(6)

        row = self.imlist.iloc[index]
        # Positional access (column 0 = path, column 1 = target) so CSVs with
        # extra columns, e.g. the bbox_* ones used below, are accepted.
        impath = row.iloc[0]

        full_imname = os.path.join(self.root, impath)
        img = read_image(full_imname)

        if self.with_bbox:
            # Bbox values come from the annotation row; pandas may hold them
            # as floats, and numpy slicing requires integers.
            x, y, w, h = (
                int(v) for v in row[['bbox_x', 'bbox_y', 'bbox_w', 'bbox_h']]
            )
            img = img[y:y + h, x:x + w, :]

        img = self.transforms(Image.fromarray(img))

        if self.is_inference:
            return img
        return img, row.iloc[1]

    def __len__(self):
        """Number of rows in the annotation CSV."""
        return len(self.imlist)


class SubmissionDataset(data.Dataset):
    """Unlabelled image dataset for producing submissions.

    Reads image paths from the ``img_path`` column of the annotation CSV and
    returns only the transformed image.

    Args:
        root: directory that ``img_path`` entries are relative to.
        annotation_file: CSV file, loaded with ``pandas.read_csv``.
        transforms: callable applied to the ``PIL.Image`` before returning it.
        with_bbox: if True, crop each image to the ``bbox_x``, ``bbox_y``,
            ``bbox_w``, ``bbox_h`` columns of its row before transforming.
    """

    def __init__(self, root, annotation_file, transforms, with_bbox=False):
        self.root = root
        self.imlist = pd.read_csv(annotation_file)
        self.transforms = transforms
        self.with_bbox = with_bbox

    def __getitem__(self, index):
        """Return the transformed image for row ``index``."""
        # Caps OpenCV's internal thread pool in the calling (worker) process.
        cv2.setNumThreads(6)

        row = self.imlist.iloc[index]
        full_imname = os.path.join(self.root, row['img_path'])
        img = read_image(full_imname)

        if self.with_bbox:
            # Explicit column names (not a label range) and int casting:
            # pandas may store bbox values as floats, slicing needs ints.
            x, y, w, h = (
                int(v) for v in row[['bbox_x', 'bbox_y', 'bbox_w', 'bbox_h']]
            )
            img = img[y:y + h, x:x + w, :]

        return self.transforms(Image.fromarray(img))

    def __len__(self):
        """Number of rows in the annotation CSV."""
        return len(self.imlist)


In [2]:
import torchvision as tv

def get_train_aug():
    """Build the training-time image transform.

    Steps: random resized crop to 224x224, random horizontal flip, conversion
    to a tensor, then per-channel normalisation with mean 0.5 and std 0.5.

    NOTE(review): open_clip pretrained CLIP backbones are usually paired with
    CLIP-specific mean/std values; confirm 0.5/0.5 is intended here.
    """
    mean, std = [0.5] * 3, [0.5] * 3
    steps = [
        tv.transforms.RandomResizedCrop((224, 224)),
        tv.transforms.RandomHorizontalFlip(),
        tv.transforms.ToTensor(),
        tv.transforms.Normalize(mean=mean, std=std),
    ]
    return tv.transforms.Compose(steps)

def get_val_aug():
    """Build the deterministic evaluation transform.

    Steps: resize (without preserving aspect ratio) to 224x224, conversion to
    a tensor, then per-channel normalisation with mean 0.5 and std 0.5 —
    matching the normalisation used by ``get_train_aug``.
    """
    mean, std = [0.5] * 3, [0.5] * 3
    steps = [
        tv.transforms.Resize((224, 224)),
        tv.transforms.ToTensor(),
        tv.transforms.Normalize(mean=mean, std=std),
    ]
    return tv.transforms.Compose(steps)


In [3]:
def get_dataloaders(data_root='/workspace/unni/data/JD_Products_10K/product_10k',
                    batch_size=128, num_workers=24):
    """Create the training and validation DataLoaders.

    Args:
        data_root: directory containing ``train/``, ``train.csv``, ``test/``
            and ``test_kaggletest.csv``.
        batch_size: batch size for both loaders.
        num_workers: DataLoader worker processes per loader (note: under DDP
            each process creates its own workers).

    Returns:
        ``(train_loader, valid_loader)``. The train loader shuffles and drops
        the last incomplete batch; the validation loader does neither.
    """
    print("Preparing train reader...")
    train_dataset = Product10KDataset(
        root=os.path.join(data_root, 'train'),
        annotation_file=os.path.join(data_root, 'train.csv'),
        transforms=get_train_aug(),
    )
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=True,
        drop_last=True,
    )
    print("Done.")

    print("Preparing valid reader...")
    val_dataset = Product10KDataset(
        root=os.path.join(data_root, 'test'),
        annotation_file=os.path.join(data_root, 'test_kaggletest.csv'),
        transforms=get_val_aug(),
    )
    valid_loader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        drop_last=False,
        pin_memory=True,
    )
    print("Done.")

    return train_loader, valid_loader

In [4]:
class ArcFace(nn.Module):
    """Additive angular margin (ArcFace) classification head.

    Computes the cosine similarity between each input embedding and each class
    weight vector.

    * With ``label``: the target-class cosine cos(theta) is replaced by
      cos(theta + m), and every logit is multiplied by ``s``.
    * Without ``label``: the raw, unscaled cosines (in [-1, 1]) are returned.

    Args:
        cin: embedding size (dim 1 of the input).
        cout: number of classes.
        s: logit scale applied on the labelled path.
        m: additive angular margin, in radians.
    """

    def __init__(self, cin, cout, s=30, m=0.3):
        super().__init__()
        self.s = s
        # Plain floats: no device placement concerns, broadcast with any tensor.
        self.sin_m = math.sin(m)
        self.cos_m = math.cos(m)
        self.cout = cout
        # Bias-free Linear kept so the parameter name (fc.weight) is unchanged.
        self.fc = nn.Linear(cin, cout, bias=False)

    def forward(self, x, label=None):
        """Return cosine logits for ``x`` of shape (batch, cin).

        Args:
            x: input embeddings, (batch, cin).
            label: optional int64 class indices, (batch,).

        Returns:
            (batch, cout) tensor: raw cosines if ``label`` is None, otherwise
            margin-adjusted logits scaled by ``s``.
        """
        # Normalise embeddings and class weights; gradients flow through both
        # normalisations (the weight norm is not detached).
        cos = F.linear(F.normalize(x, dim=1), F.normalize(self.fc.weight, dim=1))
        if label is None:
            return cos
        one_hot = F.one_hot(label, num_classes=self.cout).to(cos.dtype)
        # Rounding can push |cos| to (or past) 1, where sqrt returns NaN or has
        # an infinite gradient; the clamp keeps both finite.
        sin = torch.sqrt((1.0 - cos * cos).clamp(min=1e-7))
        # cos(theta + m) = cos(theta)cos(m) - sin(theta)sin(m)
        cos_theta_m = cos * self.cos_m - sin * self.sin_m
        logits = cos_theta_m * one_hot + cos * (1.0 - one_hot)
        return logits * self.s

class classifier_model(nn.Module):
    """Image classifier: an open_clip visual tower followed by an ArcFace head.

    Args:
        model_name: open_clip architecture name.
        pretrained: open_clip pretrained-weights tag.
        embed_dim: output width of the visual tower; must match the backbone.
            The default 768 is the value used with ``ViT-L-14``.
        num_classes: number of classes for the ArcFace head.
    """

    def __init__(self, model_name='ViT-L-14', pretrained='laion2b_s32b_b82k',
                 embed_dim=768, num_classes=9691):
        super().__init__()
        # Element 0 of create_model_and_transforms is the CLIP model; only its
        # image (visual) tower is kept.
        self.model = open_clip.create_model_and_transforms(
            model_name, pretrained=pretrained
        )[0].visual
        self.fc = ArcFace(embed_dim, num_classes)

    def forward(self, x, labels=None):
        """Embed images with the visual tower and score them with ArcFace.

        ``labels`` is forwarded to the head: when given, margin-adjusted scaled
        logits are returned; otherwise raw cosine similarities.
        """
        features = self.model(x)
        return self.fc(features, labels)

In [8]:
class VPRModule(pl.LightningModule):
    """Lightning wrapper that fine-tunes ``classifier_model`` with cross-entropy.

    Training feeds labels to the ArcFace head, so the loss is computed on
    margin-adjusted, scaled logits. Validation runs the head without labels
    and rescales the resulting cosines by the same ``s``, so val_loss is on
    the same logit scale as training.
    """

    def __init__(self):
        super().__init__()
        self.model = classifier_model()
        self.loss_module = nn.CrossEntropyLoss()

    def forward(self, img, labels=None):
        """Delegate to the wrapped model (labels optional, as in the model)."""
        return self.model(img, labels)

    def configure_optimizers(self):
        """AdamW with separate learning rates for the backbone and the head."""
        # Pretrained backbone gets a much smaller LR than the fresh ArcFace head.
        optimizer = torch.optim.AdamW(
            [
                {"params": self.model.model.parameters(), "lr": 1e-6},
                {"params": self.model.fc.parameters(), "lr": 1e-4},
            ],
            weight_decay=1e-5,
        )
        # Multiply both LRs by 0.1 at scheduler steps 20 and 24 (Lightning
        # steps schedulers once per epoch unless configured otherwise).
        scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer, milestones=[20, 24], gamma=0.1
        )
        return [optimizer], [scheduler]

    def training_step(self, batch, batch_idx):
        """Compute and log the ArcFace training loss and accuracy."""
        img, labels = batch
        preds = self.model(img, labels)
        loss = self.loss_module(preds, labels)
        # Measured on margin-penalised logits, so it understates the accuracy
        # the model would get on plain cosines.
        acc = (preds.argmax(dim=-1) == labels).float().mean()
        self.log("train_acc", acc, on_step=True, on_epoch=True)
        self.log("train_loss", loss, on_step=True, on_epoch=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Compute and log validation loss/accuracy, aggregated per epoch."""
        img, labels = batch
        # Without labels the head returns cosines in [-1, 1]; cross-entropy on
        # those is ~log(num_classes) regardless of quality. Apply the training
        # scale so val_loss is comparable to train_loss (argmax is unaffected).
        preds = self.model(img) * self.model.fc.s
        loss = self.loss_module(preds, labels)
        acc = (preds.argmax(dim=-1) == labels).float().mean()
        # sync_dist: average across DDP processes so every rank (and the
        # checkpoint callback) sees the same epoch-level value.
        self.log("val_acc", acc, on_epoch=True, sync_dist=True)
        self.log("val_loss", loss, on_epoch=True, sync_dist=True)

In [9]:
# Build the training and validation DataLoaders (paths and loader settings are defined in get_dataloaders).
train_loader, val_loader = get_dataloaders()

Preparing train reader...
Done.
Preparing valid reader...
Done.


In [10]:
from pytorch_lightning.callbacks import ModelCheckpoint

# Keep the 10 checkpoints with the highest validation accuracy.
# mode="max" is required: ModelCheckpoint defaults to mode="min", which would
# retain the epochs with the *lowest* val_acc.
checkpoint_callback = ModelCheckpoint(
    dirpath="/workspace/unni/code/simple_fine_tune/model_saves",
    save_top_k=10,
    monitor="val_acc",
    mode="max",
)
model = VPRModule()
# The callback only takes effect if it is registered with the Trainer.
trainer = pl.Trainer(max_epochs=25, accelerator='gpu', callbacks=[checkpoint_callback])
trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Initializing distributed: GLOBAL_RANK: 1, MEMBER: 2/6
Initializing distributed: GLOBAL_RANK: 0, MEMBER: 1/6
Initializing distributed: GLOBAL_RANK: 3, MEMBER: 4/6
Initializing distributed: GLOBAL_RANK: 2, MEMBER: 3/6
Initializing distributed: GLOBAL_RANK: 4, MEMBER: 5/6
Initializing distributed: GLOBAL_RANK: 5, MEMBER: 6/6
----------------------------------------------------------------------------------------------------
distributed_backend=nccl
All distributed processes registered. Starting with 6 processes
----------------------------------------------------------------------------------------------------

Missing logger folder: /workspace/unni/code/simple_fine_tune/lightning_logs
Missing logger folder: /workspace/unni/code/simple_fine_tune/lightning_logs
Missing logger folder: /workspace/unni/code/simple_fine_tune/lightning_logs
Miss

Sanity Checking: 0it [00:00, ?it/s]

  rank_zero_warn(


Training: 0it [00:00, ?it/s]



Validation: 0it [00:00, ?it/s]



Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

`Trainer.fit` stopped: `max_epochs=25` reached.
