In [1]:
from pathlib import Path
import matplotlib.pyplot as plt
from tqdm import tqdm
import albumentations as A
import cv2
from sklearn.model_selection import train_test_split
import torch
from torch.utils.data import DataLoader
from computer_vision.self_supervised.dino.func import ImageData, CollateFn, CollateSingleImage, ImageOriginalData, Model, clip_loss, LightningModel, Config
from lightning import Trainer, seed_everything
from lightning.pytorch.callbacks import LearningRateMonitor

# Fix the global random seed before any data splitting or model creation.
# The call returns the seed it set, which is what the cell displays (42).
seed_everything(42)

Global seed set to 42


42

In [2]:
# Folder holding the raw product images (all *.jpg files are used).
path = Path.home() / 'OneDrive - Seagroup/ai/computer_vison/shopee_price_matching/shopee_ds/images'

# Directory listings come back in a filesystem-dependent order, so sort the
# paths first; otherwise the seeded split below differs between machines/runs.
files = sorted(str(file) for file in path.glob("*.jpg"))
if not files:
    # Fail early with the offending path instead of a generic split error.
    raise FileNotFoundError(f"No .jpg images found in {path}")

# Hold out 15% of the images for validation with a fixed seed.
train_files, valid_files = train_test_split(files, test_size=0.15, random_state=42)

# Training pipeline: shuffled batches, incomplete last batch dropped.
train_data = ImageData(train_files)
train_dl = DataLoader(
    train_data,
    Config.batch_size,
    shuffle=True,
    drop_last=True,
    num_workers=Config.num_workers,
    pin_memory=True,
    collate_fn=CollateFn(),
)

# Validation pipeline: fixed order, every image kept, twice the batch size.
valid_data = ImageOriginalData(valid_files)
valid_dl = DataLoader(
    valid_data,
    Config.batch_size*2,
    shuffle=False,
    drop_last=False,
    num_workers=Config.num_workers,
    pin_memory=True,
    collate_fn=CollateSingleImage(),
)

In [3]:
# Sanity check: display the first item produced by the training dataset.
train_data[0]

(tensor([[[1.3927, 1.3927, 1.3927,  ..., 1.6667, 1.6667, 1.6838],
          [1.3927, 1.3927, 1.3927,  ..., 1.6838, 1.6838, 1.7009],
          [1.3927, 1.3927, 1.3927,  ..., 1.6667, 1.6667, 1.6667],
          ...,
          [1.0844, 1.0844, 1.0844,  ..., 1.3755, 1.1015, 0.7419],
          [1.0844, 1.0844, 1.0673,  ..., 0.9646, 0.9988, 1.0502],
          [1.0673, 1.0673, 1.0502,  ..., 0.8276, 1.2899, 0.6734]],
 
         [[1.5357, 1.5357, 1.5357,  ..., 1.8158, 1.8158, 1.8333],
          [1.5357, 1.5357, 1.5357,  ..., 1.8333, 1.8333, 1.8508],
          [1.5357, 1.5357, 1.5357,  ..., 1.8158, 1.8158, 1.8158],
          ...,
          [1.2206, 1.2206, 1.2206,  ..., 1.5007, 1.2206, 0.8529],
          [1.2206, 1.2206, 1.2031,  ..., 1.0980, 1.1155, 1.1681],
          [1.2206, 1.2206, 1.1856,  ..., 0.9580, 1.3957, 0.7304]],
 
         [[1.8383, 1.8383, 1.8383,  ..., 2.1171, 2.1171, 2.1520],
          [1.8383, 1.8383, 1.8383,  ..., 2.1520, 2.1346, 2.1520],
          [1.8383, 1.8383, 1.8383,  ...,

In [4]:
# Sanity check: display the first item produced by the validation dataset.
valid_data[0]

tensor([[[1.4098, 1.3755, 1.4440,  ..., 1.0502, 0.9646, 1.1700],
         [1.2728, 1.3927, 1.4098,  ..., 1.1872, 0.9817, 1.1015],
         [1.3584, 1.3755, 1.4098,  ..., 1.2214, 0.9988, 1.0673],
         ...,
         [1.3755, 1.4612, 1.3070,  ..., 1.0159, 1.2899, 1.2385],
         [1.4440, 1.3755, 1.2899,  ..., 1.2214, 1.2728, 1.2385],
         [1.4612, 1.3584, 1.3413,  ..., 1.1700, 1.3242, 1.2214]],

        [[1.5707, 1.5357, 1.6057,  ..., 1.1856, 1.1155, 1.3081],
         [1.4307, 1.5532, 1.5707,  ..., 1.3256, 1.1155, 1.2381],
         [1.5182, 1.5357, 1.5707,  ..., 1.3606, 1.1331, 1.2031],
         ...,
         [1.5182, 1.6057, 1.4482,  ..., 1.1856, 1.4657, 1.4132],
         [1.5882, 1.5182, 1.4307,  ..., 1.3957, 1.4482, 1.4132],
         [1.6057, 1.5007, 1.4832,  ..., 1.3431, 1.5007, 1.3957]],

        [[1.8208, 1.7860, 1.8557,  ..., 1.4897, 1.4025, 1.6117],
         [1.6814, 1.7860, 1.8208,  ..., 1.6291, 1.4200, 1.5420],
         [1.7511, 1.7685, 1.8034,  ..., 1.6640, 1.4374, 1.

In [None]:
# Training setup: a fresh model wrapped in the Lightning module and fitted on
# the GPU for a fixed number of epochs with 16-bit precision.
epochs = 2
teacher = Model()
lightning_model = LightningModel(
    model=teacher,
    loss_fn=clip_loss,
    learning_rate=1e-6,
    weight_decay=0.1,
    max_epochs=epochs,
    valid_files=valid_files,
)

# Record the learning rate at every step so the schedule can be inspected.
lr_monitor = LearningRateMonitor(logging_interval='step')
trainer = Trainer(
    max_epochs=epochs,
    accelerator='gpu',
    precision=16,
    deterministic=True,
    callbacks=[lr_monitor],
)

# Fit on the training loader and validate on the held-out loader.
trainer.fit(lightning_model, train_dl, valid_dl)

Using 16bit None Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA GeForce RTX 3060 Ti') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type  | Params
--------------------------------
0 | model | Model | 34.5 M
--------------------------------
34.5 M    Trainable params
0         Non-trainable params
34.5 M    Total params
68.926    Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")


In [None]:
# Loader over the validation images, in file order, for embedding extraction.
image_orig_data = ImageOriginalData(valid_files)
image_orig_dl = DataLoader(
    image_orig_data,
    Config.batch_size*2,
    shuffle=False,
    drop_last=False,
    num_workers=Config.num_workers,
    pin_memory=True,
    collate_fn=CollateSingleImage(),
)

# Use the GPU when one is visible, otherwise fall back to the CPU so this
# cell still runs on machines without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
teacher = teacher.eval().to(device)

# Collect per-batch outputs under their own name so `embedding` is only ever
# a tensor (it was previously a list first and a tensor afterwards).
embedding_batches = []
with torch.no_grad():
    for x in tqdm(image_orig_dl):
        out = teacher(x.to(device))
        # Move each batch back to host memory right away to free device memory.
        embedding_batches.append(out.cpu())

# Stack all batches along dim 0; presumably one row per validation image,
# in the same order as `valid_files` (the loader is not shuffled).
embedding = torch.cat(embedding_batches, dim=0)


  0%|          | 0/235 [00:00<?, ?it/s][A

In [None]:
# Shared transform used when displaying images: rescale to 256 x 256.
resize = A.Resize(height=256, width=256)


def get_closest(embedding: torch.Tensor, i: int, k: int = 5):
    """Return the ``k`` rows of ``embedding`` most similar to row ``i``.

    Similarity is the plain dot product between row ``i`` and every row, so
    the scores are cosine similarities only if the rows are L2-normalised.
    The query row itself is included in the candidates.

    Args:
        embedding: 2-D tensor with one embedding per row.
        i: Index of the query row.
        k: Maximum number of neighbours to return. Defaults to 5 (the value
            that used to be hard-coded) and is clamped to the number of rows
            so small embedding sets no longer raise.

    Returns:
        Tuple ``(scores, idx)`` from ``Tensor.topk``: the similarity scores in
        descending order and the row indices they belong to.
    """
    # Indexing a row already yields a 1-D vector; no transpose is needed
    # (``.T`` on a 1-D tensor is a deprecated no-op).
    similarity = embedding @ embedding[i]
    k = min(k, similarity.shape[0])
    scores, idx = similarity.topk(k)
    return scores, idx

def read_image(file):
    """Load an image from disk and return it in RGB channel order.

    Args:
        file: Path to the image file (``str`` or path-like object).

    Returns:
        The decoded image after converting OpenCV's BGR layout to RGB.

    Raises:
        FileNotFoundError: If OpenCV could not read the file (missing path,
            unsupported or corrupt image).
    """
    # cv2.imread reports failure by returning None rather than raising, which
    # would otherwise surface as an opaque error inside cvtColor.
    image = cv2.imread(str(file))
    if image is None:
        raise FileNotFoundError(f"Could not read image: {file}")
    return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

def plot_closest_pairs(embedding, i, files):
    """Show image ``i`` and, below it, a row of its most similar images.

    Args:
        embedding: Embedding tensor handed to ``get_closest``.
        i: Index of the query image, valid for both ``embedding`` and ``files``.
        files: Image paths; ``files[j]`` must correspond to row ``j`` of
            ``embedding``.
    """
    # Draw the query on its own explicit figure instead of relying on
    # pyplot's implicit "current figure".
    query = resize(image=read_image(files[i]))['image']
    query_fig, query_ax = plt.subplots()
    query_ax.imshow(query)
    query_ax.set_title("Query image")
    query_ax.axis('off')

    scores, idx = get_closest(embedding, i)
    # Convert to plain Python numbers for list indexing and string formatting.
    neighbours = list(zip(idx.tolist(), scores.tolist()))

    # squeeze=False keeps `axes` two-dimensional, so iterating its single row
    # also works when only one neighbour is returned.
    fig, axes = plt.subplots(1, len(neighbours), figsize=(12, 5), squeeze=False)
    for (neighbour, score), ax in zip(neighbours, axes[0]):
        # `neighbour` is deliberately not called `i`, which would shadow the
        # query index parameter.
        match = resize(image=read_image(files[neighbour]))['image']
        ax.imshow(match)
        ax.set_title(f"Score: {score:.2f}")
        ax.axis('off')

    plt.show()
    
    
# Visual check: use the validation image at index 1 as the query.
i = 1
plot_closest_pairs(embedding, i=i, files=valid_files)