In [1]:
import sys

sys.path.append("c:\pixt")

from dataset import Pixt_Dataset, Pixt_Test_Dataset
from dataset.transform import Pixt_ImageTransform, Pixt_TextTransform, Pixt_TargetTransform
from metrics import Accuracy

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

import clip
from tqdm import tqdm
from PIL import Image

from loss import MultiLabelSoftMarginLoss
from metrics import Accuracy
from module import BaselineLitModule

# train dataset class instance

In [2]:
img_dir = ".././data/"
train_annotation_dir = ".././data/annotation/annotation_merged_remove_gap/train.csv"
valid_annotation_dir = ".././data/annotation/annotation_merged_remove_gap/valid.csv"
image_transform = Pixt_ImageTransform()
train_dataset = Pixt_Dataset(img_dir, train_annotation_dir, image_transform)
valid_dataset = Pixt_Dataset(img_dir, valid_annotation_dir, image_transform)

# train dataloader class instance

In [3]:
text_transform = Pixt_TextTransform(
    max_length=300,
    classes_ko_dir="c:\\pixt\\data\\annotation\\annotation_merged_remove_gap\\all_class_list_ko.pt",
    classes_en_dir="c:\\pixt\\data\\annotation\\annotation_merged_remove_gap\\all_class_list_en.pt",
)
target_transform = Pixt_TargetTransform(max_length=300)

def collate_fn(samples):
    image_tensor = torch.stack([sample["image_tensor"] for sample in samples], dim=0)
    text_dict = text_transform([sample["text_ko"] for sample in samples])
    target_tensor = target_transform(text_dict["text_en"], text_dict["text_input"])

    input_data = text_dict
    text_dict["image_tensor"] = image_tensor
    text_dict["target_tensor"] = target_tensor
    return input_data

In [4]:
train_dataloader = DataLoader(
    dataset=train_dataset,
    shuffle=False,
    drop_last=True,
    num_workers=0,
    batch_size=1,
    persistent_workers=False,
    collate_fn=collate_fn
)
valid_dataloader = DataLoader(
    dataset=valid_dataset,
    shuffle=False,
    drop_last=True,
    num_workers=0,
    batch_size=16,
    persistent_workers=False,
    collate_fn=collate_fn
)

# train

In [32]:
device = "cuda" if torch.cuda.is_available() else "cpu"
model, _ = clip.load("RN50", device=device)

# base_loss = MultiLabelSoftMarginLoss(base_loss_weight=1)
# accuracy = Accuracy()

# lit_module = BaselineLitModule(
#     clip_model=model,
#     base_loss_func=base_loss,
#     accuracy=accuracy,
#     optim=torch.optim.Adam,
#     lr=5.0e-05,
# )
# # all models RN50
# # original dataset & learning rate 1.0e-06
# ckpt_path_0 = "c:/pixt/outputs/pixt_baseline/lightning_logs/version_0/epoch=50-step=17340.ckpt"
# # remove mgf dataset & learning rate 1.0e-06
# ckpt_path_1 = "c:/pixt/outputs/pixt_baseline/lightning_logs/version_1/epoch=76-step=22715.ckpt"
# # original dataset & learning rate 5.0e-05
# ckpt_path_2 = "c:/pixt/outputs/pixt_baseline/lightning_logs/version_2/epoch=83-step=28560.ckpt"
# # remove mgf dataset & learning rate 5.0e-05
# ckpt_path_3 = "c:/pixt/outputs/pixt_baseline/lightning_logs/version_3/epoch=40-step=12095.ckpt"

# ckpt_path_4 = "c:/pixt/outputs/pixt_baseline/lightning_logs/version_4/epoch=19-step=109080.ckpt"

# lit_module.load_state_dict(torch.load(ckpt_path_4)["state_dict"])
# model = lit_module._clip_model

In [33]:
model.visual.conv1.weight.grad

In [34]:
loss_func = nn.MultiLabelSoftMarginLoss()
loss_func = nn.MSELoss()
accuracy = Accuracy()
optimizer = optim.Adam(model.parameters(),lr=1e-3,betas=(0.9,0.98),eps=1e-6,weight_decay=0.2) 

In [35]:
sim = torch.nn.CosineSimilarity()
epochs =100
for epoch in range(epochs):
    model.train()
    for batch in tqdm(train_dataloader):
        image_tensor = batch["image_tensor"].to(device)
        text_ko = batch["text_ko"]
        text_en = batch["text_en"]
        text_input = batch["text_input"]
        text_tensor = batch["text_tensor"].to(device)
        target_tensor = batch["target_tensor"].to(device)

        image_features = model.encode_image(image_tensor)
        text_features = model.encode_text(text_tensor)
        similarity = sim(image_features, text_features)
        # image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        # text_features = text_features / text_features.norm(dim=-1, keepdim=True)
        # similarity = image_features @ text_features.T
        similarity = similarity.float()
        target_tensor = target_tensor.float()
        loss = loss_func(similarity, target_tensor)
        # acc = accuracy(similarity, text_en, text_input)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        print(epoch, loss)
#        print(torch.mean(model.visual.conv1.weight.grad))
        print('similarity true', torch.mean(similarity[:14]))
        print('similarity false', torch.mean(similarity[14:28]))

        break
    # model.eval()
    # for batch in tqdm(valid_dataloader):
    #     image_tensor = batch["image_tensor"].to(device)
    #     text_ko = batch["text_ko"]
    #     text_en = batch["text_en"]
    #     text_input = batch["text_input"]
    #     text_tensor = batch["text_tensor"].to(device)
    #     target_tensor = batch["target_tensor"].to(device)

    #     image_features = model.encode_image(image_tensor)
    #     text_features = model.encode_text(text_tensor)
    #     image_features = image_features / image_features.norm(dim=-1, keepdim=True)
    #     text_features = text_features / text_features.norm(dim=-1, keepdim=True)
    #     similarity = image_features @ text_features.T

    #     loss = loss_func(similarity, target_tensor)
    #     acc = accuracy(similarity, text_en, text_input)
    #     break
    # break

  0%|          | 0/5454 [00:06<?, ?it/s]


0 tensor(0.0634, device='cuda:0', grad_fn=<MseLossBackward0>)
similarity true tensor(0.1932, device='cuda:0', grad_fn=<MeanBackward0>)
similarity false tensor(0.1896, device='cuda:0', grad_fn=<MeanBackward0>)


  0%|          | 0/5454 [00:00<?, ?it/s]

1 tensor(0.0467, device='cuda:0', grad_fn=<MseLossBackward0>)
similarity true tensor(0., device='cuda:0', grad_fn=<MeanBackward0>)
similarity false tensor(0., device='cuda:0', grad_fn=<MeanBackward0>)


  0%|          | 0/5454 [00:06<?, ?it/s]
  0%|          | 0/5454 [00:06<?, ?it/s]


2 tensor(0.0467, device='cuda:0', grad_fn=<MseLossBackward0>)
similarity true tensor(0., device='cuda:0', grad_fn=<MeanBackward0>)
similarity false tensor(0., device='cuda:0', grad_fn=<MeanBackward0>)


  0%|          | 0/5454 [00:06<?, ?it/s]


KeyboardInterrupt: 

In [24]:
text_tensor[0][:10] , text_tensor[2][:10]

(tensor([49406,   320,  1125,   539,   320,  1901,  3750, 49407,     0,     0],
        device='cuda:0', dtype=torch.int32),
 tensor([49406,   320,  1125,   539,   320, 11489, 49407,     0,     0,     0],
        device='cuda:0', dtype=torch.int32))

In [25]:
text_features

tensor([[-0.1621,  1.7803,  0.1146,  ...,  0.8813, -0.1995, -1.1270],
        [-0.1570,  1.6768,  0.3010,  ...,  1.7754, -0.1083, -1.1777],
        [-0.0588,  1.7988, -0.2881,  ...,  0.6582, -0.8076, -1.4912],
        ...,
        [-0.1108,  1.8359,  0.0908,  ...,  1.6758, -0.3726, -1.1572],
        [ 0.0410,  1.8467, -0.0949,  ...,  0.6787, -0.2791, -0.7520],
        [-0.1907,  2.7988, -0.0188,  ...,  1.1338,  0.1713, -1.5039]],
       device='cuda:0', dtype=torch.float16, grad_fn=<MmBackward0>)

In [None]:
torch.nn.CosineSimilarity()(image_features ,  text_features)

In [None]:
similarity

In [None]:
target_tensor

In [None]:
device = "cuda" if torch.cuda.is_available() else "cpu"
model, _ = clip.load("RN50", device=device)

base_loss = MultiLabelSoftMarginLoss(base_loss_weight=1)
accuracy = Accuracy()

lit_module = BaselineLitModule(
    clip_model=model,
    base_loss_func=base_loss,
    accuracy=accuracy,
    optim=torch.optim.Adam,
    lr=5.0e-05,
)
# all models RN50
# original dataset & learning rate 1.0e-06
ckpt_path_0 = "c:/pixt/outputs/pixt_baseline/lightning_logs/version_0/epoch=50-step=17340.ckpt"
# remove mgf dataset & learning rate 1.0e-06
ckpt_path_1 = "c:/pixt/outputs/pixt_baseline/lightning_logs/version_1/epoch=76-step=22715.ckpt"
# original dataset & learning rate 5.0e-05
ckpt_path_2 = "c:/pixt/outputs/pixt_baseline/lightning_logs/version_2/epoch=83-step=28560.ckpt"
# remove mgf dataset & learning rate 5.0e-05
ckpt_path_3 = "c:/pixt/outputs/pixt_baseline/lightning_logs/version_3/epoch=40-step=12095.ckpt"

ckpt_path_4 = "c:/pixt/outputs/pixt_baseline/lightning_logs/version_4/epoch=19-step=109080.ckpt"

lit_module.load_state_dict(torch.load(ckpt_path_3)["state_dict"])
model = lit_module._clip_model
model

# test

In [None]:
img_dir = ".././data/"
image_transform = Pixt_ImageTransform()
test_dataset = Pixt_Test_Dataset(img_dir, image_transform)

In [None]:
test_dataloader = DataLoader(
    dataset=test_dataset,
    shuffle=False,
    drop_last=False,
    num_workers=0,
    batch_size=1,
    persistent_workers=False,
)
image_tensor = None
for batch in test_dataloader:
    image_filename = batch["image_filename"]
    image_tensor = batch["image_tensor"]
    print(image_filename, image_tensor.shape)
    break

In [None]:
with torch.no_grad():
    image_features = model.encode_image(image_tensor)
    text_features = model.encode_text(text_tensor)
image_features.shape, text_features.shape

In [None]:
classes_list = torch.load(".././data/annotation/all_class_list_en.pt")
classes_list = [tag_ko.lower() for tag_ko in classes_list]
classes_list = sorted(set(classes_list))
text_input = torch.cat([clip.tokenize(f"a photo of a {c}") for c in classes_list]).to(device)

image_number = 1
file_path = ".././data/dataset3/"+ str(image_number) + ".webp"
print(file_path)
Image.open(file_path).show()

Image_transform = Pixt_ImageTransform()
image_input = Image_transform(Image.open(file_path).convert("RGB")).float().unsqueeze(0).to(device)
print(image_input.shape, text_input.shape)

with torch.no_grad():
    image_features = model.encode_image(image_input)
    text_features = model.encode_text(text_input)
image_features.shape, text_features.shape

In [None]:
image_features /= image_features.norm(dim=-1, keepdim=True)
text_features /= text_features.norm(dim=-1, keepdim=True)
image_features.shape, text_features.shape

In [None]:
similarity = (100.0 * image_features @ text_features.T).softmax(dim=-1)
values, indices = similarity[0].topk(14)

classes_ko_list = torch.load(".././data/annotation/all_class_list_ko.pt")
classes_en_list = torch.load(".././data/annotation/all_class_list_en.pt")
print("\nTop predictions:\n")
for value, index in zip(values, indices):
    en_word = text_input[index]
    print(f"{en_word:>16s}: {100 * value.item():.100f}%")
    # print(classes_ko_list[classes_en_list.index(en_word)])

In [None]:
(100.0 * image_features @ text_features.T)

In [None]:

indices