<a href="https://colab.research.google.com/github/imemmul/GenerativeNFT/blob/main/ViTfinetune.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from transformers import ViTForImageClassification
import pandas as pd

In [None]:
# Preprocessing applied to every image before it reaches the ViT.
# The checkpoint fine-tuned in this notebook
# (google/vit-base-patch16-224-in21k) was pre-trained on 224x224 inputs
# normalised with mean = std = 0.5 per channel, not with the ImageNet
# statistics commonly used for CNNs, so the same statistics are used here.
VIT_IMAGE_MEAN = [0.5, 0.5, 0.5]
VIT_IMAGE_STD = [0.5, 0.5, 0.5]

transform = transforms.Compose([
    # ToTensor comes first because the dataset hands over a NumPy array,
    # which Resize cannot consume directly.
    transforms.ToTensor(),
    transforms.Resize((224, 224)),
    transforms.Normalize(mean=VIT_IMAGE_MEAN, std=VIT_IMAGE_STD),
])

In [None]:
# Mount Google Drive (Colab only) so files under /content/drive are reachable.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
# Collection-name prefixes for which a per-collection maximum rank is computed.
valid_list = [
    "azuki",
    "sappy-seals",
    "killabears",
    "lazy-lions",
    "genuine-undead",
    "genesis-creepz",
    "bastard-gan-punks-v2",
    "pudgypenguins",
    "beanzofficial",
    "ninja-squad-official",
    "azragames-thehopeful",
    "thewarlords",
    "parallel-avatars",
    "pixelmongen1",
    "kanpai-pandas",
]

In [None]:
# Location of the label CSV on the mounted Drive (a file path, despite the "_dir" name).
labels_dir = "/content/drive/MyDrive/rarity_dataset/labels.csv"

In [None]:
# Load the label table for the exploratory cells that follow.
df = pd.read_csv(labels_dir)

In [None]:
# Element-wise check of which rows carry a non-missing 'label' value.
df['label'].notna()

0        True
1        True
2        True
3        True
4        True
         ... 
22414    True
22415    True
22416    True
22417    True
22418    True
Name: label, Length: 22419, dtype: bool

In [None]:
# Total number of rows in the label table.
len(df)

22419

In [None]:
def extract_rank(row):
    """Return the ``'rank'`` entry of a parsed label, or ``None`` if unavailable.

    Args:
        row: A parsed label. Expected to be a dict; anything else (``None``,
            a number, a string, a list, ...) is treated as "no rank".

    Returns:
        ``row['rank']`` when ``row`` is a dict containing the key ``'rank'``,
        otherwise ``None``.
    """
    # The isinstance guard keeps non-dict literals (e.g. an int or a str
    # produced by literal_eval) from raising TypeError on the lookup.
    if isinstance(row, dict) and 'rank' in row:
        return row['rank']
    return None

In [None]:
import ast

In [None]:
def convert_to_dict(string_repr):
    """Parse the string form of a Python dict literal into a dict.

    Args:
        string_repr: Text such as ``"{'rank': 2608, 'score': None}"``.
            Non-string input (e.g. a NaN cell) is tolerated.

    Returns:
        The parsed dict, or ``None`` when the input cannot be parsed or
        does not evaluate to a dict.
    """
    try:
        parsed = ast.literal_eval(string_repr)
    except (SyntaxError, ValueError, TypeError):
        # literal_eval also raises TypeError for some malformed input,
        # e.g. a dict literal with an unhashable key.
        return None
    # Only hand back real dicts so callers can rely on the mapping interface.
    return parsed if isinstance(parsed, dict) else None

In [None]:
# Inspect one raw label: it is a string holding a dict literal, which is why it is parsed below.
df['label'][1]

"{'strategy_id': None, 'strategy_version': None, 'rank': 2608, 'score': None, 'calculated_at': '', 'max_rank': None, 'tokens_scored': 0, 'ranking_features': None}"

In [None]:
# Parse every raw label string with convert_to_dict and store the result in a new column.
df['dict_values'] = df['label'].apply(convert_to_dict)

In [None]:
# Pull the 'rank' entry out of each parsed label (missing where unavailable).
df['rank_values'] = df['dict_values'].apply(extract_rank)

In [None]:
# Preview the first few extracted ranks.
df['rank_values'].head()

0    7008.0
1    2608.0
2    5465.0
3    7386.0
4    5917.0
Name: rank_values, dtype: float64

In [None]:
from torch.utils.data import Dataset
import pandas as pd
from PIL import Image
import math
import os
import numpy as np

class RarityDataset(Dataset):
    """Image dataset whose target is a rank normalised per collection.

    Each CSV row names an image file (``data_name``, whose leading part up to
    the first underscore is the collection name) and carries a ``label``
    string holding a dict literal with a ``'rank'`` entry. The target returned
    for a sample is ``rank / max_rank_of_its_collection``.

    Args:
        csv_dir: Path of the label CSV (must have ``data_name`` and ``label``
            columns).
        col_names: Collection-name prefixes for which a maximum rank is
            computed.
        image_dir: Directory containing the image files named in ``data_name``.
        transform: Callable applied to the RGB image array, or ``None``.

    Attributes:
        labels: The filtered label table (index is 0..len-1).
        col_max_rarity: Mapping collection name -> maximum rank; collections
            whose maximum is NaN are removed.
        collection_drop: Names of the collections removed for that reason.
    """

    def __init__(self, csv_dir, col_names, image_dir, transform):
        self.labels = pd.read_csv(csv_dir)
        self.col_names = col_names
        self.transform = transform
        self.image_dir = image_dir
        # Parse this dataset's own CSV rather than a notebook-global frame,
        # so the parsed labels always line up with self.labels.
        self.labels['dict'] = self.labels['label'].apply(convert_to_dict)
        self.labels['rank_values'] = self.labels['dict'].apply(extract_rank)
        self.col_max_rarity = self.calculate_rarity()
        self.drop_nan_ones()

    def drop_nan_ones(self):
        """Remove collections without a usable maximum rank, then incomplete rows.

        Prints ``<collection>:<max>`` for every removed collection, deletes
        its rows and its ``col_max_rarity`` entry, drops every remaining row
        containing a missing value, and renumbers the index from zero.
        """
        self.collection_drop = [
            name for name, max_rank in self.col_max_rarity.items() if pd.isna(max_rank)
        ]
        names = self.labels['data_name']
        keep = pd.Series(True, index=self.labels.index)
        for name in self.collection_drop:
            print(f"{name}:{self.col_max_rarity.pop(name)}")
            # na=False: rows with a missing file name never match a prefix.
            keep &= ~names.str.startswith(name, na=False)
        # drop=True keeps the old index from leaking in as an extra column.
        self.labels = self.labels[keep].dropna().reset_index(drop=True)

    def __len__(self):
        """Return the number of samples left after filtering."""
        return len(self.labels)

    def calculate_rarity(self):
        """Return ``{collection: max rank}`` for every name in ``col_names``.

        A collection with no matching rows, or without any valid rank, maps
        to NaN.
        """
        names = self.labels["data_name"]
        ranks = self.labels["rank_values"]
        return {
            col: ranks[names.str.startswith(col, na=False)].max()
            for col in self.col_names
        }

    def __getitem__(self, index):
        """Return ``(image, normalised_rank)`` for the sample at position ``index``."""
        # Positional access: independent of whatever index the frame carries.
        data_name = self.labels['data_name'].iloc[index]
        rank = self.labels['rank_values'].iloc[index]
        # The collection name is the part of the file name before the first
        # underscore; this assumes collection names contain no underscore.
        col_name = data_name.split("_")[0]
        img_path = os.path.join(self.image_dir, data_name)
        # The context manager closes the file handle once the pixels are read.
        with Image.open(img_path) as pil_img:
            img = np.array(pil_img.convert('RGB'))
        if self.transform is not None:
            img = self.transform(img)
        return img, rank / self.col_max_rarity[col_name]

In [None]:
# Build the dataset for the collections in valid_list; images are read from the given Drive folder.
rarity_dataset = RarityDataset(labels_dir, valid_list, "/content/drive/MyDrive/rarity_dataset", transform=transform)

sappy-seals:nan
genesis-creepz:nan
pixelmongen1:nan


In [None]:
# Per-collection maximum rank, after collections with a NaN maximum were removed.
rarity_dataset.col_max_rarity

{'azuki': 10000.0,
 'killabears': 3333.0,
 'lazy-lions': 9997.0,
 'genuine-undead': 9983.0,
 'bastard-gan-punks-v2': 11303.0,
 'pudgypenguins': 8886.0,
 'beanzofficial': 19946.0,
 'ninja-squad-official': 8881.0,
 'azragames-thehopeful': 5541.0,
 'thewarlords': 9999.0,
 'parallel-avatars': 10998.0,
 'kanpai-pandas': 6998.0}

+++++++


In [None]:
# Number of label rows remaining after filtering.
len(rarity_dataset.labels)

17042

In [None]:
# File names of the remaining samples.
rarity_dataset.labels['data_name']

0        ninja-squad-official_8870.png
1        ninja-squad-official_8869.png
2        ninja-squad-official_8868.png
3        ninja-squad-official_8867.png
4        ninja-squad-official_8866.png
                     ...              
17037              killabears_1789.png
17038              killabears_1819.png
17039              killabears_1848.png
17040              killabears_1788.png
17041              killabears_1818.png
Name: data_name, Length: 17042, dtype: object

In [None]:
from torch.utils.data import random_split

# Fixed seed so the 80/20 split is identical on every run; without it the
# train/test membership changes each time the cell is executed and results
# are not comparable across runs.
SPLIT_SEED = 42

train_size = int(0.8 * len(rarity_dataset))
test_size = len(rarity_dataset) - train_size

train_dataset, test_dataset = random_split(
    rarity_dataset,
    [train_size, test_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

print("Training set size:", len(train_dataset))
print("Testing set size:", len(test_dataset))

Training set size: 13633
Testing set size: 3409


In [None]:
import torch

In [None]:
# Both loaders share the batch size; only the training loader reshuffles its samples.
loader_kwargs = {"batch_size": 8}
train_loader = torch.utils.data.DataLoader(train_dataset, shuffle=True, **loader_kwargs)
test_loader = torch.utils.data.DataLoader(test_dataset, shuffle=False, **loader_kwargs)

In [None]:
# Load the pre-trained ViT backbone with a single-output head for regression.
# Passing num_labels=1 (instead of swapping model.classifier afterwards)
# keeps model.config consistent with the head, so the fine-tuned model can be
# saved and reloaded without a classifier shape mismatch.
model = ViTForImageClassification.from_pretrained(
    'google/vit-base-patch16-224-in21k',
    num_labels=1,
)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Move the model to its device before the optimizer captures the parameters.
model.to(device)
optimizer = optim.SGD(model.parameters(), lr=1e-3)
criterion = nn.MSELoss()

config.json:   0%|          | 0.00/502 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/346M [00:00<?, ?B/s]

Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


ViTForImageClassification(
  (vit): ViTModel(
    (embeddings): ViTEmbeddings(
      (patch_embeddings): ViTPatchEmbeddings(
        (projection): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
      )
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (encoder): ViTEncoder(
      (layer): ModuleList(
        (0-11): 12 x ViTLayer(
          (attention): ViTAttention(
            (attention): ViTSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
            (output): ViTSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
          )
          (intermediate): ViTIntermediate(
            (dense): Linear(in_features=7

In [None]:
# Number of test batches times the batch size (8). This is an upper bound on the
# sample count: the last batch may be smaller, so it can exceed len(test_dataset).
len(test_loader) * 8

3416

In [None]:
def val(model, test_loader, device, loss_fn=None):
  """Evaluate ``model`` on ``test_loader`` and return the mean per-batch loss.

  Args:
    model: Module whose forward output exposes ``.logits`` with one value
      per sample in the last dimension.
    test_loader: Iterable of ``(images, labels)`` batches supporting ``len()``.
    device: Device the batches are moved to before the forward pass.
    loss_fn: Loss callable ``(predictions, labels) -> scalar tensor``.
      Defaults to the notebook-level ``criterion``.

  Returns:
    The average of the batch losses as a Python float. Each batch's loss is
    also printed while validating.
  """
  if loss_fn is None:
    # Backward-compatible default: the loss defined in the model-setup cell.
    loss_fn = criterion
  model.eval()
  avg_loss = 0
  with torch.no_grad():
    for img, label in test_loader:
      img, label = img.to(device), label.to(device)
      outputs = model(img)
      # squeeze(-1) removes only the output dimension. A bare squeeze()
      # would also collapse the batch dimension of a final batch holding a
      # single sample, leaving predictions and labels with different shapes.
      predictions = outputs.logits.to(torch.float64).squeeze(-1)
      batch_loss = loss_fn(predictions, label).item()
      avg_loss += batch_loss
      print(f"validating: loss{batch_loss}")
  return avg_loss / len(test_loader)



In [None]:
# Fine-tune for a fixed number of epochs, validating after each one.
num_epochs = 20
for epoch in range(num_epochs):
    model.train()
    train_loss = 0
    for i, (images, labels) in enumerate(train_loader):
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        # squeeze(-1) removes only the output dimension. A bare squeeze()
        # would also collapse the batch dimension when the final batch holds
        # a single sample, leaving predictions and labels with different shapes.
        predictions = outputs.logits.to(torch.float64).squeeze(-1)
        loss = criterion(predictions, labels)
        loss.backward()
        optimizer.step()
        batch_loss = loss.item()
        print(f"iteration {i}, loss: {batch_loss}")
        train_loss += batch_loss
    val_loss = val(model, test_loader, device)
    print(f"Epoch [{epoch+1}/{num_epochs}], Train Loss: {train_loss / len(train_loader)} val_loss: {val_loss}")



iteration 0, loss: 0.22152560223149048
iteration 1, loss: 0.4144515388211822
iteration 2, loss: 0.1497582207398216
iteration 3, loss: 0.08469826273630164
iteration 4, loss: 0.26041337938183223
iteration 5, loss: 0.15492285844843873
iteration 6, loss: 0.12130514794689982
iteration 7, loss: 0.2109514387080232
iteration 8, loss: 0.13508121544345533
iteration 9, loss: 0.18231523288584103
iteration 10, loss: 0.1500906758904546
iteration 11, loss: 0.10021295149910683
iteration 12, loss: 0.1691015983854367
iteration 13, loss: 0.16884040713145487
iteration 14, loss: 0.12651442271534136
iteration 15, loss: 0.18524898193387695
iteration 16, loss: 0.05963478252271162
iteration 17, loss: 0.19907521610785392
iteration 18, loss: 0.12975687506366956
iteration 19, loss: 0.06079509183079543
iteration 20, loss: 0.13711216920815256
iteration 21, loss: 0.0589182424627608
iteration 22, loss: 0.12520277528077958
iteration 23, loss: 0.18538394385403095
iteration 24, loss: 0.13884344929982825
iteration 25, lo



iteration 133, loss: 0.12507389206764946
iteration 134, loss: 0.11459541894887507
iteration 135, loss: 0.05866995282235481
iteration 136, loss: 0.03779255348125987
iteration 137, loss: 0.12510138309428104
iteration 138, loss: 0.08861520512536132
iteration 139, loss: 0.13862595293647934
iteration 140, loss: 0.07997693001361031
iteration 141, loss: 0.05134594911114313
iteration 142, loss: 0.11530980524662289
iteration 143, loss: 0.08410675044449238
iteration 144, loss: 0.09071420003169073
iteration 145, loss: 0.08960318686389704
iteration 146, loss: 0.11186679056679598
iteration 147, loss: 0.10630389282753766
iteration 148, loss: 0.10362307527509626
iteration 149, loss: 0.08297255973995282
iteration 150, loss: 0.056700894141914246
iteration 151, loss: 0.1329465525713114
iteration 152, loss: 0.10268157477884879
iteration 153, loss: 0.16818326790011645
iteration 154, loss: 0.07022642323449335
iteration 155, loss: 0.08644000107417864
iteration 156, loss: 0.09267006798364857
iteration 157, l