<a href="https://colab.research.google.com/github/imemmul/GenerativeNFT/blob/rarity/ViT/ViTfinetune.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [138]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from transformers import ViTForImageClassification

In [199]:
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((512, 512)),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

In [140]:
# from google.colab import drive
# drive.mount('/content/drive')

In [141]:
valid_list = ["azuki",
"sappy-seals",
"killabears",
"lazy-lions",
"genuine-undead",
"genesis-creepz",
"bastard-gan-punks-v2",
"pudgypenguins",
"beanzofficial",
"ninja-squad-official",
"azragames-thehopeful",
"thewarlords",
"parallel-avatars",
"pixelmongen1",
"kanpai-pandas"]

In [142]:
labels_dir = "/Users/emirulurak/Desktop/dev/ozu/openseadata/dataset/rarity_dataset/labels.csv"

In [143]:
df = pd.read_csv(labels_dir)

In [144]:
df['label'].notna()

0        True
1        True
2        True
3        True
4        True
         ... 
22414    True
22415    True
22416    True
22417    True
22418    True
Name: label, Length: 22419, dtype: bool

In [145]:
def extract_rank(row):
    return row['rank'] if row and 'rank' in row else None

In [146]:
import ast

In [147]:
def convert_to_dict(string_repr):
    try:
        return ast.literal_eval(string_repr)
    except (SyntaxError, ValueError):
        return None

In [148]:
df['label'][1]

"{'strategy_id': None, 'strategy_version': None, 'rank': 2608, 'score': None, 'calculated_at': '', 'max_rank': None, 'tokens_scored': 0, 'ranking_features': None}"

In [149]:
df['dict_values'] = df['label'].apply(convert_to_dict)

In [150]:
df['rank_values'] = df['dict_values'].apply(extract_rank)

In [256]:
from torch.utils.data import Dataset
import pandas as pd
from PIL import Image
import math
import os
import numpy as np

class RarityDataset(Dataset):
    def __init__(self, csv_dir, col_names, image_dir, transform):
        self.labels = pd.read_csv(csv_dir)
        self.col_names = col_names
        self.transform = transform
        self.labels['dict'] = df['label'].apply(convert_to_dict)
        self.labels['rank_values'] = self.labels["dict"].apply(extract_rank)
        self.col_max_rarity = self.calculate_rarity()
        self.drop_nan_ones()
        self.image_dir = image_dir
    
    def drop_nan_ones(self):
        max_col_rarity = self.col_max_rarity.copy()
        self.collection_drop = []
        for key, val in max_col_rarity.items():
            if math.isnan(val):
                print(f"{key}:{val}")
                self.col_max_rarity.pop(key)
                self.collection_drop.append(key)
        for key in self.collection_drop:
            self.labels.drop(self.labels[self.labels['data_name'].str.startswith(key)].index, inplace=True)
        self.labels.reset_index()
    
    def __len__(self):
        return len(self.labels)
    
    def calculate_rarity(self):
        max_col_rarities = {}
        for col in self.col_names:
            filtered_df = self.labels[self.labels["data_name"].str.startswith(col)]
            max_col_rarities[col] = filtered_df["rank_values"].max()
        return max_col_rarities
    
    def __getitem__(self, index):
        col_name = self.labels['data_name'][index].split("_")[0]
        print(col_name)
        img_dir = os.path.join(self.image_dir, self.labels['data_name'][index])
        img = np.array(Image.open(img_dir).convert('RGB'))
        if transform:
            img = transform(img)
        return img, self.labels['rank_values'][index] / self.col_max_rarity[col_name]

In [257]:
rarity_dataset = RarityDataset(labels_dir, valid_list, "/Users/emirulurak/Desktop/dev/ozu/openseadata/dataset/rarity_dataset", transform=transform)

sappy-seals:nan
genesis-creepz:nan
pixelmongen1:nan


In [258]:
rarity_dataset.col_max_rarity

{'azuki': 10000.0,
 'killabears': 3333.0,
 'lazy-lions': 9997.0,
 'genuine-undead': 9983.0,
 'bastard-gan-punks-v2': 11303.0,
 'pudgypenguins': 8886.0,
 'beanzofficial': 19946.0,
 'ninja-squad-official': 8881.0,
 'azragames-thehopeful': 5541.0,
 'thewarlords': 9999.0,
 'parallel-avatars': 10998.0,
 'kanpai-pandas': 6998.0}

+++++++


In [272]:
rarity_dataset.labels['data_name']

0        ninja-squad-official_8870.png
1        ninja-squad-official_8869.png
2        ninja-squad-official_8868.png
3        ninja-squad-official_8867.png
4        ninja-squad-official_8866.png
                     ...              
22414              killabears_1789.png
22415              killabears_1819.png
22416              killabears_1848.png
22417              killabears_1788.png
22418              killabears_1818.png
Name: data_name, Length: 17920, dtype: object

In [260]:
img, label = rarity_dataset[0]

ninja-squad-official




In [261]:
from torch.utils.data import random_split

train_size = int(0.8 * len(rarity_dataset))
test_size = len(rarity_dataset) - train_size

train_dataset, test_dataset = random_split(rarity_dataset, [train_size, test_size])

print("Training set size:", len(train_dataset))
print("Testing set size:", len(test_dataset))

Training set size: 14336
Testing set size: 3584


In [262]:
import torch

In [263]:
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=32, shuffle=True)

In [264]:
model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224-in21k')
model.classifier = nn.Linear(model.config.hidden_size, 2)

Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [265]:
optimizer = optim.Adam(model.parameters(), lr=1e-4)
criterion = nn.CrossEntropyLoss()

In [266]:
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
model.to(device)

ViTForImageClassification(
  (vit): ViTModel(
    (embeddings): ViTEmbeddings(
      (patch_embeddings): ViTPatchEmbeddings(
        (projection): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
      )
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (encoder): ViTEncoder(
      (layer): ModuleList(
        (0-11): 12 x ViTLayer(
          (attention): ViTAttention(
            (attention): ViTSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
            (output): ViTSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
          )
          (intermediate): ViTIntermediate(
            (dense): Linear(in_features=7

In [267]:
num_epochs = 20
for epoch in range(num_epochs):
    model.train()
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs.logits, labels)
        loss.backward()
        optimizer.step()

        print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item()}")

KeyError: 17448