In [1]:
import torch
import time
import torch.nn as nn
import unicodedata
import torch.optim as optim
from torch.utils.data import random_split, DataLoader
import ast
from torch.utils.data import Dataset
import pandas as pd
from PIL import Image
import math
import os
import numpy as np
from torchvision import datasets, transforms
from transformers import ViTForImageClassification
import pandas as pd
import argparse
from sklearn.metrics import accuracy_score, f1_score, precision_score, confusion_matrix
import torch
import seaborn as sns
import matplotlib.pyplot as plt
from IPython.display import display
from PIL import Image
import numpy as np
import cv2

def fix_dir(image_dir):
    """
    this function takes a image dir and fixs it so that it exist in dataset
    """
    norms = ['NFC', 'NFD', 'NFKC', 'NFKD']
    if not os.path.isfile(image_dir):
        for norm in norms:
            print(f"norming: {norm}")
            img_dir_normalized = unicodedata.normalize(norm, image_dir)
            img_dir_normalized = img_dir_normalized.replace("'", "_")
            if os.path.isfile(img_dir_normalized):
                return img_dir_normalized
    return image_dir


def extract_rank(row):
    return row['rank'] if row and 'rank' in row else None

def convert_to_dict(string_repr):
    try:
        return ast.literal_eval(string_repr)
    except (SyntaxError, ValueError):
        return None
    

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((224, 224)),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    transforms.RandomHorizontalFlip(),
    transforms.RandomErasing(),
    transforms.ColorJitter()
])


test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((224, 224)),
])


class RarityDatasetMerged(Dataset):
    def __init__(self, label_dir_old, label_dir_new, image_dir_old, image_dir_new, transform):
        self.images_dir_old = image_dir_old
        self.images_dir_new = image_dir_new
        self.labels_old = pd.read_csv(label_dir_old)
        self.labels_new = pd.read_csv(label_dir_new)
        self.transform = transform
        
    def __len__(self):
        return len(self.labels_old) + len(self.labels_new)


    def __getitem__(self, index):
        try:
            if index >= len(self.labels_old):
                index = index - len(self.labels_old)
                img_dir = os.path.join(self.images_dir_new, self.labels_new.iloc[index].data_name)
                img_dir_fixed = fix_dir(img_dir)
                img = np.array(Image.open(img_dir_fixed).convert('RGB'))
                if self.transform:
                    img = self.transform(img)
                return img, self.labels_new.iloc[index].cls
            else:
                img_dir = os.path.join(self.images_dir_old, self.labels_old.iloc[index].data_name + ".png")
                img_dir_fixed = fix_dir(img_dir)
                img = np.array(Image.open(img_dir_fixed).convert('RGB'))
                if self.transform:
                    img = self.transform(img)
                return img, self.labels_old.iloc[index].cls_label
                    
        except Exception as e:
            raise RuntimeError(f"Error loading image at index {index}: {str(e)}")


def _arg_parse():
    args = argparse.ArgumentParser()
    args.add_argument("--images_dir_old", type=str)
    args.add_argument("--images_dir_new", type=str)
    args.add_argument("--merge", action="store_true")
    args.add_argument("--split_ratio", type=float)
    args.add_argument("--label_dir_old", type=str)
    args.add_argument("--label_dir_new", type=str)
    args.add_argument("--num_epochs", type=int)
    args.add_argument("--train", action="store_true")
    args.add_argument("--checkpoint", type=str)
    return args.parse_args()

def val(model, test_loader, device, criterion):
    model.eval()
    avg_loss = 0
    with torch.no_grad():
        for i, batch in enumerate(test_loader):
            img, labels = batch
            img, labels = img.to(device), labels.to(device)
            outputs = model(img)
            labels = labels.unsqueeze(1).to(outputs.logits.dtype)
            probs = torch.sigmoid(outputs.logits)
            loss = criterion(probs, labels)
            avg_loss += loss.item()
            print(f"{i} iteration, Validating: loss {loss.item()}")
    return avg_loss / len(test_loader)


def _train():
    # args = _arg_parse()
    args = {}
    args['label_dir_old'] = "/home/emir/Desktop/dev/datasets/nft_dataset/labels_augmented_old_collection.csv"
    args['label_dir_new'] = "/home/emir/Desktop/dev/datasets/nft_dataset/labels_augmented.csv"
    args['images_dir_new'] = "/home/emir/Desktop/dev/datasets/nft_dataset/NFT_DATASET_MERGED/new_collection"
    args['images_dir_old'] = "/home/emir/Desktop/dev/datasets/nft_dataset/NFT_DATASET_MERGED/old_collection"
    args['train'] = False
    args['split_ratio'] = 0.85
    args['checkpoint'] = "/home/emir/Desktop/dev/datasets/weights/run_01_merged_clsf.pt"
    if args['train']:
        dataset = RarityDatasetMerged(label_dir_old=args['label_dir_old'], label_dir_new=args['label_dir_new'],
                                      image_dir_old=args['images_dir_old'], image_dir_new=args['images_dir_new'], transform=transform)
    else:
        dataset = RarityDatasetMerged(label_dir_old=args['label_dir_old'], label_dir_new=args['label_dir_new'],
                                      image_dir_old=args['images_dir_old'], image_dir_new=args['images_dir_new'], transform=test_transform)
    model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224-in21k')
    model.classifier = nn.Linear(model.config.hidden_size, 1) # TODO
    if args['checkpoint']:
        model.to("cuda")
        model.load_state_dict(torch.load(args['checkpoint']))
        model.eval()
    train_dataset, val_dataset = random_split(dataset, [int(len(dataset)*args['split_ratio']), int(len(dataset) - int(len(dataset)*args['split_ratio']))])
    print(len(train_dataset))
    print(len(val_dataset))
    if args['train']:
        print("Training set size:", len(train_dataset))
        train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)
        val_loader = DataLoader(val_dataset, batch_size=8, shuffle=False)
        print("Val set size:", len(val_dataset))
    else:
        print("Test set size:", len(val_dataset)) 
        val_loader = DataLoader(val_dataset, batch_size=8, shuffle=False)
    count = 0
    for _ in range(2000):
        model.eval()
        with torch.no_grad():
            rnd_i = np.random.randint(len(val_dataset))
            img, label = val_dataset[rnd_i]
            output = model(img.unsqueeze(0).to("cuda"))
            label = "Rare" if label else "Common"
            predicted_ = "Rare" if torch.sigmoid(output.logits).detach().cpu().numpy() > 0.5 else "Common"
            img_numpy = (img.numpy() * 255).astype(np.uint8)  # Convert tensor to numpy array with values in [0, 255]
            img_transposed = np.transpose(img_numpy, (1, 2, 0))  # Transpose dimensions to (224, 224, 3)
            if predicted_ == label:
                count += 1
            # img_pil = Image.fromarray(img_transposed)  # Convert numpy array to PIL Image
            # display(img_pil)
            print(f"Predicted: {predicted_}, True Label: {label}")
    return count

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
count = _train()
print(count)

Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


90322
15940
Test set size: 15940
Predicted: Rare, True Label: Rare
Predicted: Rare, True Label: Rare




Predicted: Common, True Label: Common
Predicted: Common, True Label: Common
Predicted: Rare, True Label: Rare
Predicted: Rare, True Label: Rare
Predicted: Rare, True Label: Rare
Predicted: Common, True Label: Common
Predicted: Rare, True Label: Rare
Predicted: Rare, True Label: Rare
Predicted: Common, True Label: Common
Predicted: Common, True Label: Common
Predicted: Rare, True Label: Rare
Predicted: Rare, True Label: Rare
Predicted: Rare, True Label: Rare
Predicted: Common, True Label: Rare
Predicted: Rare, True Label: Rare
Predicted: Rare, True Label: Common
Predicted: Rare, True Label: Rare
Predicted: Common, True Label: Common




Predicted: Common, True Label: Common
Predicted: Rare, True Label: Rare
Predicted: Common, True Label: Common
Predicted: Rare, True Label: Rare
Predicted: Rare, True Label: Rare
Predicted: Common, True Label: Common
Predicted: Common, True Label: Common
Predicted: Common, True Label: Rare
Predicted: Common, True Label: Common
Predicted: Common, True Label: Common
Predicted: Rare, True Label: Rare
Predicted: Common, True Label: Rare
Predicted: Common, True Label: Common
Predicted: Rare, True Label: Rare
Predicted: Common, True Label: Common
Predicted: Rare, True Label: Rare
Predicted: Common, True Label: Rare
Predicted: Rare, True Label: Rare
Predicted: Common, True Label: Common
Predicted: Rare, True Label: Rare
Predicted: Common, True Label: Common
Predicted: Common, True Label: Common
Predicted: Rare, True Label: Rare
Predicted: Common, True Label: Common
Predicted: Common, True Label: Common
Predicted: Common, True Label: Common
Predicted: Rare, True Label: Rare
Predicted: Common, T