In [1]:
import pandas as pd
from sklearn.preprocessing import LabelEncoder, MinMaxScaler
import numpy as np
import string
import spacy
from transformers import DistilBertTokenizer, DistilBertModel
import torch
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
from PIL import Image
import requests
from io import BytesIO
from tqdm import tqdm
import concurrent.futures
import time
import random
import ast
import matplotlib.pyplot as plt
import cupy as cp
from torch.cuda.amp import autocast, GradScaler
import re
import os
import zipfile

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

le = LabelEncoder()

scaler = MinMaxScaler()

spacy.prefer_gpu() 

use_gpu = True

nlp = spacy.load('en_core_web_sm', disable=["ner", "parser", "tagger"])

tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
model = DistilBertModel.from_pretrained('distilbert-base-uncased')
model.to(device)

DistilBertModel(
  (embeddings): Embeddings(
    (word_embeddings): Embedding(30522, 768, padding_idx=0)
    (position_embeddings): Embedding(512, 768)
    (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (transformer): Transformer(
    (layer): ModuleList(
      (0-5): 6 x TransformerBlock(
        (attention): DistilBertSdpaAttention(
          (dropout): Dropout(p=0.1, inplace=False)
          (q_lin): Linear(in_features=768, out_features=768, bias=True)
          (k_lin): Linear(in_features=768, out_features=768, bias=True)
          (v_lin): Linear(in_features=768, out_features=768, bias=True)
          (out_lin): Linear(in_features=768, out_features=768, bias=True)
        )
        (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
        (ffn): FFN(
          (dropout): Dropout(p=0.1, inplace=False)
          (lin1): Linear(in_features=768, out_features=3072, bias=True)
          (lin2): L

In [2]:
final_cards = pd.read_csv("preprocessed_data.csv")

In [3]:
def flat_array_to_image(flattened_array, height, width, use_gpu=True):
    if use_gpu:
        reshaped_array = cp.array(flattened_array).reshape((height, width, 3)).astype(cp.uint8)
        reshaped_array = cp.asnumpy(reshaped_array)
    else:
        reshaped_array = np.array(flattened_array).reshape((height, width, 3)).astype(np.uint8)

    image = Image.fromarray(reshaped_array)
    return image

def sanitize_flattened_array_string(array_string):
    sanitized_string = array_string.replace(' ', ' ')
    return sanitized_string
height, width = 64, 64
use_gpu = False

cache_path = 'processed_large.pt'

if os.path.exists(cache_path):
    # Load the cached data
    print("Loading cached images...")
    processed_images = torch.load(cache_path)
else:
    print("Processing images...")
    processed_images = [
        flat_array_to_image(np.array(ast.literal_eval(sanitize_flattened_array_string(x))), height, width, use_gpu) 
        for x in tqdm(final_cards['large_64x64'], desc='Processing large_64x64 images', position=0)
    ]
    torch.save(processed_images, cache_path)

final_cards['large_64x64'] = processed_images

print(f"Processed {len(processed_images)} images.")

Loading cached images...


  processed_images = torch.load(cache_path)


Processed 12369 images.


cache_path = 'processed_pokemon.pt'

if os.path.exists(cache_path):
    # Load the cached data
    print("Loading cached images...")
    processed_images = torch.load(cache_path)
else:
    print("Processing images...")
    # Process and save the data
    processed_images = [
        flat_array_to_image(np.array(ast.literal_eval(sanitize_flattened_array_string(x))), height, width, use_gpu) 
        for x in tqdm(final_cards['pokemon_64x64'], desc='Processing pokemon_64x64 images', position=0)
    ]
    torch.save(processed_images, cache_path)

final_cards = pd.read_parquet('file.parquet')

In [4]:
final_cards.head()

Unnamed: 0.1,Unnamed: 0,id,name,supertype,subtypes,hp,types,artist,rarity,set_name,large_64x64,pokemon_64x64,caption_embeddings,pokemon_intro_embeddings
0,0,0,22,0,[17],0.16129,[9],85,7,8,<PIL.Image.Image image mode=RGB size=64x64 at ...,"[255, 255, 255, 255, 255, 255, 255, 255, 255, ...",[-3.96634698e-01 -1.86856285e-01 2.77578920e-...,[-5.26714206e-01 -4.67754155e-01 2.97648579e-...
1,1,11,178,0,[17],0.225806,[10],85,7,8,<PIL.Image.Image image mode=RGB size=64x64 at ...,"[255, 255, 255, 255, 255, 255, 255, 255, 255, ...",[-3.51396680e-01 -1.25379905e-01 3.18560898e-...,[-3.65637928e-01 -8.96108747e-02 1.75579965e-...
2,2,22,285,0,[1],0.290323,[0],85,7,8,<PIL.Image.Image image mode=RGB size=64x64 at ...,"[255, 255, 255, 255, 255, 255, 255, 255, 255, ...",[-4.59370106e-01 -1.54282287e-01 2.86983043e-...,[-4.82110918e-01 -7.08584368e-01 1.12857044e-...
3,3,33,287,0,[17],0.290323,[5],118,7,8,<PIL.Image.Image image mode=RGB size=64x64 at ...,"[255, 255, 255, 255, 255, 255, 255, 255, 255, ...",[-4.22779500e-01 -1.33981556e-01 2.65981674e-...,[-3.93124610e-01 -1.14848644e-01 2.64526993e-...
4,4,44,331,0,[1],0.032258,[0],85,7,8,<PIL.Image.Image image mode=RGB size=64x64 at ...,"[255, 255, 255, 255, 255, 255, 255, 255, 255, ...",[-4.79958475e-01 -1.35355964e-01 3.01812291e-...,[-3.27733248e-01 -4.54510987e-01 2.78392106e-...


In [5]:
def clean_embedding_string(embedding_str):
    cleaned_str = re.sub(r'\s+', ',', embedding_str.strip())
    return cleaned_str

class FinalCardsDataset(Dataset):
    def __init__(self, dataframe, cache_dir="cache/"):
        self.data = dataframe

        self.image_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Resize((64, 64)),
        ])
        
        self.cache_dir = cache_dir
        
        if not os.path.exists(cache_dir):
            os.makedirs(cache_dir)

        self.preprocessed_data = self._load_or_cache_data()

    def _load_or_cache_data(self):
        cached_data_path = os.path.join(self.cache_dir, "preprocessed_data.pt")
        if os.path.exists(cached_data_path):
            print("Loading cached data.")
            return torch.load(cached_data_path)
        else:
            print("Preprocessing data...")
            preprocessed_data = self._preprocess_data()
            torch.save(preprocessed_data, cached_data_path)
            return preprocessed_data
    
    def _preprocess_data(self):
        preprocessed_data = []
        for idx, row in self.data.iterrows():
            try:

                types = torch.tensor(ast.literal_eval(row['types']), dtype=torch.long)
                subtypes = torch.tensor(ast.literal_eval(row['subtypes']), dtype=torch.long)

                max_length = max(len(types), len(subtypes))
                types = F.pad(types, (0, max_length - len(types)), value=0)
                subtypes = F.pad(subtypes, (0, max_length - len(subtypes)), value=0)
                
                numerical_data = {
                    'id': torch.tensor(row['id'], dtype=torch.long),
                    'name': torch.tensor(row['name'], dtype=torch.long),
                    'supertype': torch.tensor(row['supertype'], dtype=torch.long),
                    'artist': torch.tensor(row['artist'], dtype=torch.long),
                    'rarity': torch.tensor(row['rarity'], dtype=torch.long),
                    'set_name': torch.tensor(row['set_name'], dtype=torch.long),
                    'hp': torch.tensor(row['hp'], dtype=torch.float),
                }
                
                large_64x64 = self.image_transform(row['large_64x64'])
                
                preprocessed_data.append({
                    'numerical_data': numerical_data,
                    'types': types,
                    'subtypes': subtypes,
                    'large_64x64': large_64x64
                })
            
            except Exception as e:
                print(f"Error processing row {idx}: {e}")
                continue

        print(f"Total processed data: {len(preprocessed_data)}")
        return preprocessed_data
    
    def __len__(self):
        return len(self.preprocessed_data)
    
    def __getitem__(self, idx):
        return self.preprocessed_data[idx]

class VisualTransformersWithAttention(nn.Module):
    def __init__(self, embed_dim, num_heads, num_layers, output_dim, image_size):
        super(VisualTransformersWithAttention, self).__init__()
        self.embed_dim = embed_dim
        self.image_size = image_size
        
        self.image_embed = nn.Conv2d(3, embed_dim, kernel_size=1)
        
        self.numerical_embed = nn.Linear(6, embed_dim)
        
        self.encoder_layers = nn.ModuleList([
            nn.MultiheadAttention(embed_dim, num_heads) for _ in range(num_layers)
        ])
        
        self.hp_output = nn.Linear(embed_dim, 1) 
        self.attack_output = nn.Linear(embed_dim, 1) 
    
    def forward(self, image, numerical_data):
        batch_size, _, height, width = image.size()
        image_features = self.image_embed(image)
        image_features = image_features.flatten(2).permute(2, 0, 1) 
        
        # Numerical feature extraction
        numerical_features = self.numerical_embed(numerical_data)
        
        combined_features = torch.cat([image_features.mean(dim=0), numerical_features], dim=-1)
        
        attentions = []
        for layer in self.encoder_layers:
            combined_features, attn_weights = layer(combined_features.unsqueeze(0), combined_features.unsqueeze(0), combined_features.unsqueeze(0))
            attentions.append(attn_weights)
        
        hp_pred = self.hp_output(combined_features.mean(dim=0))
        attack_pred = self.attack_output(combined_features.mean(dim=0))
        
        return hp_pred, attack_pred, attentions 


def multi_task_loss(predictions, targets):
    hp_pred, attack_pred = predictions
    hp_target, attack_target = targets

    hp_loss = F.mse_loss(hp_pred.squeeze(), hp_target)
    attack_loss = F.mse_loss(attack_pred.squeeze(), attack_target)
    
    total_loss = hp_loss + attack_loss
    return total_loss


def visualize_attention(image, attention_weights, feature_name, save_path=None):
    seq_len = int(np.sqrt(attention_weights.shape[0]))
    attention_map = attention_weights.mean(0).reshape(seq_len, seq_len)  # Average across heads
    
    attention_map = F.interpolate(
        torch.tensor(attention_map).unsqueeze(0).unsqueeze(0),
        size=image.shape[:2],
        mode="bilinear",
        align_corners=False
    ).squeeze().detach().numpy()
    
    attention_map = (attention_map - attention_map.min()) / (attention_map.max() - attention_map.min())
    
    plt.figure(figsize=(8, 8))
    plt.imshow(image)
    plt.imshow(attention_map, cmap="jet", alpha=0.5)
    plt.title(f"Attention Map for {feature_name}")
    plt.axis("off")
    
    if save_path:
        plt.savefig(save_path)
    else:
        plt.show()


def generate_card(model, image, text_embeddings):
    model.eval()
    
    with torch.no_grad():
        image = image.to(device)
        text_embeddings = text_embeddings.to(device)
        
        hp_output, attack_output, attentions = model(image.unsqueeze(0), text_embeddings.unsqueeze(0))
        return hp_output.cpu().numpy(), attack_output.cpu().numpy(), attentions

In [6]:
print("1")
dataset = FinalCardsDataset(final_cards.sample(n=100, random_state=42))
print("2")
batch_size = min(2, len(dataset))
data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=0, pin_memory=False)
print("3")

embed_dim = 10
num_heads = 5
num_layers = 2
output_dim = 2
image_size = 32

print("4")
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = VisualTransformersWithAttention(embed_dim, num_heads, num_layers, output_dim, image_size).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
criterion = nn.MSELoss()

print("5")
num_epochs = 3
model.train()
scaler = GradScaler()

for epoch in range(num_epochs):
    print(f"Epoch {epoch+1}/{num_epochs}")
    epoch_loss = 0
    progress_bar = tqdm(data_loader, desc=f"Epoch {epoch + 1}/{num_epochs}")
    
    for batch in progress_bar:
        numerical_data = batch['numerical_data']['hp']
        image = batch['large_64x64']
        
        with autocast():
            outputs, attentions = model(image, numerical_data)
            
            loss = criterion(outputs, numerical_data)
        
        epoch_loss += loss.item()
        
        optimizer.zero_grad()
        
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        
        progress_bar.set_postfix({"Loss": loss.item()})
    
    avg_epoch_loss = epoch_loss / len(data_loader)
    print(f"Epoch {epoch + 1}, Avg Loss: {avg_epoch_loss:.4f}")

  scaler = GradScaler()


1
Preprocessing data...
Total processed data: 100
2
3
4
5
Epoch 1/3


  with autocast():
Epoch 1/3:   0%|                                                                                | 0/50 [00:00<?, ?it/s]


RuntimeError: Input type (float) and bias type (struct c10::Half) should be the same

In [33]:
def clean_embedding_string(embedding_str):
    return re.sub(r'\s+', ',', embedding_str.strip())

def custom_collate_fn(batch):
    numerical_data = [item['numerical_data']['hp'] for item in batch]
    image_data = [item['large_64x64'] for item in batch]
    
    image_data = torch.stack(image_data, dim=0)

    return {'numerical_data': numerical_data, 'image': image_data}

class FinalCardsDataset(Dataset):
    def __init__(self, dataframe, cache_dir="cache/"):
        self.data = dataframe
        self.image_transform = transforms.Compose([transforms.ToTensor(), transforms.Resize((64, 64))])
        self.cache_dir = cache_dir
        if not os.path.exists(cache_dir):
            os.makedirs(cache_dir)
        self.preprocessed_data = self._load_or_cache_data()

    def _load_or_cache_data(self):
        cached_data_path = os.path.join(self.cache_dir, "preprocessed_data.pt")
        if os.path.exists(cached_data_path):
            return torch.load(cached_data_path)
        preprocessed_data = self._preprocess_data()
        torch.save(preprocessed_data, cached_data_path)
        return preprocessed_data
    
    def _preprocess_data(self):
        preprocessed_data = []
        for idx, row in self.data.iterrows():
            try:
                numerical_data = torch.tensor([
                    row['id'], row['name'], row['supertype'], row['artist'],
                    row['rarity'], row['set_name'], row['hp']], dtype=torch.float)
                
                large_64x64 = self.image_transform(row['large_64x64'])
                
                preprocessed_data.append({
                    'numerical_data': numerical_data,
                    'large_64x64': large_64x64
                })
            except Exception:
                continue
        return preprocessed_data
    
    def __len__(self):
        return len(self.preprocessed_data)
    
    def __getitem__(self, idx):
        return self.preprocessed_data[idx]

class VisualTransformersWithAttention(nn.Module):
    def __init__(self, embed_dim=128, num_heads=4, num_layers=2, output_dim=1, image_size=32):
        super(VisualTransformersWithAttention, self).__init__()

        self.image_embed = nn.Conv2d(3, embed_dim, kernel_size=3, stride=1, padding=1)  # Ensure output channels match embed_dim
        
        self.encoder_layers = nn.ModuleList(
            [nn.TransformerEncoderLayer(d_model=embed_dim, nhead=num_heads) for _ in range(num_layers)]
        )
        
        self.fc_output = nn.Linear(embed_dim, output_dim)
        
    def forward(self, image, numerical_data=None):
        image_features = self.image_embed(image).flatten(2).permute(2, 0, 1)

        for layer in self.encoder_layers:
            image_features = layer(image_features)

        output = self.fc_output(image_features.mean(dim=0))
        return output

def multi_task_loss(predictions, targets):
    hp_pred, attack_pred = predictions
    hp_target, attack_target = targets
    hp_loss = F.mse_loss(hp_pred.squeeze(), hp_target)
    attack_loss = F.mse_loss(attack_pred.squeeze(), attack_target)
    return hp_loss + attack_loss

def visualize_attention(image, attention_weights, feature_name, save_path=None):
    seq_len = int(np.sqrt(attention_weights.shape[0]))
    attention_map = attention_weights.mean(0).reshape(seq_len, seq_len)
    attention_map = F.interpolate(torch.tensor(attention_map).unsqueeze(0).unsqueeze(0), size=image.shape[:2], mode="bilinear", align_corners=False).squeeze().detach().numpy()
    attention_map = (attention_map - attention_map.min()) / (attention_map.max() - attention_map.min())
    plt.imshow(image)
    plt.imshow(attention_map, cmap="jet", alpha=0.5)
    plt.title(f"Attention Map for {feature_name}")
    plt.axis("off")
    if save_path:
        plt.savefig(save_path)
    else:
        plt.show()

def generate_card(model, image, text_embeddings):
    model.eval()
    with torch.no_grad():
        image = image.to(device)
        text_embeddings = text_embeddings.to(device)
        hp_output, attack_output, attentions = model(image.unsqueeze(0), text_embeddings.unsqueeze(0))
        return hp_output.cpu().numpy(), attack_output.cpu().numpy(), attentions

In [38]:
dataset = FinalCardsDataset(final_cards.sample(n=100, random_state=42))
batch_size = min(2, len(dataset))
data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=0, pin_memory=False, collate_fn=custom_collate_fn)

embed_dim = 128
num_heads = 4
num_layers = 2
output_dim = 1
image_size = 32

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = VisualTransformersWithAttention(embed_dim, num_heads, num_layers, output_dim, image_size).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
criterion = nn.MSELoss()

num_epochs = 3
model.train()

for epoch in range(num_epochs):
    epoch_loss = 0
    progress_bar = tqdm(data_loader, desc=f"Epoch {epoch + 1}/{num_epochs}")
    
    for batch in progress_bar:
        numerical_data = batch['numerical_data'].to(device) 
        image = batch['image'].to(device)
        
        outputs, attentions = model(image, numerical_data)
    
        loss = criterion(outputs, numerical_data)
        
        epoch_loss += loss.item()
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        progress_bar.set_postfix({"Loss": loss.item()})

    avg_epoch_loss = epoch_loss / len(data_loader)
    print(f"Epoch {epoch + 1}, Avg Loss: {avg_epoch_loss:.4f}")


  return torch.load(cached_data_path)
Epoch 1/3:   0%|                                                                                | 0/50 [00:00<?, ?it/s]


AttributeError: 'list' object has no attribute 'to'