In [1]:
import pandas as pd
from sklearn.preprocessing import LabelEncoder, MinMaxScaler
import numpy as np
import string
import spacy
from transformers import DistilBertTokenizer, DistilBertModel
import torch
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
from PIL import Image
import requests
from io import BytesIO
from tqdm import tqdm
import concurrent.futures
import time
import random
import ast
import matplotlib.pyplot as plt
import json

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

le = LabelEncoder()

scaler = MinMaxScaler()

spacy.prefer_gpu() 

nlp = spacy.load('en_core_web_sm', disable=["ner", "parser", "tagger"])

tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
model = DistilBertModel.from_pretrained('distilbert-base-uncased')
model.to(device)

DistilBertModel(
  (embeddings): Embeddings(
    (word_embeddings): Embedding(30522, 768, padding_idx=0)
    (position_embeddings): Embedding(512, 768)
    (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (transformer): Transformer(
    (layer): ModuleList(
      (0-5): 6 x TransformerBlock(
        (attention): DistilBertSdpaAttention(
          (dropout): Dropout(p=0.1, inplace=False)
          (q_lin): Linear(in_features=768, out_features=768, bias=True)
          (k_lin): Linear(in_features=768, out_features=768, bias=True)
          (v_lin): Linear(in_features=768, out_features=768, bias=True)
          (out_lin): Linear(in_features=768, out_features=768, bias=True)
        )
        (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
        (ffn): FFN(
          (dropout): Dropout(p=0.1, inplace=False)
          (lin1): Linear(in_features=768, out_features=3072, bias=True)
          (lin2): L

In [2]:
raw_cards1 = pd.read_csv('https://raw.githubusercontent.com/tooniez/pokemoncards_tcg/main/cards.csv')
cards1 = raw_cards1[['id', 'name', 'supertype', 'subtypes', 'hp', 'types' , 'artist', 'rarity', 'large_image_source']].copy()

raw_cards2 = pd.read_parquet("hf://datasets/bhavnicksm/PokemonCardsPlus/data/train-00000-of-00001-015be9de37300028.parquet")
cards2 = raw_cards2[['id', 'pokemon_image', 'caption', 'pokemon_intro', 'set_name']].copy()

merged_cards = pd.merge(cards1, cards2, on='id', how='inner')

shared_records_count = len(merged_cards)

print(f'Number of shared records: {shared_records_count}')

merged_cards = merged_cards.dropna()

Number of shared records: 13139


In [3]:
merged_cards.info()

<class 'pandas.core.frame.DataFrame'>
Index: 12373 entries, 0 to 13138
Data columns (total 13 columns):
 #   Column              Non-Null Count  Dtype  
---  ------              --------------  -----  
 0   id                  12373 non-null  object 
 1   name                12373 non-null  object 
 2   supertype           12373 non-null  object 
 3   subtypes            12373 non-null  object 
 4   hp                  12373 non-null  float64
 5   types               12373 non-null  object 
 6   artist              12373 non-null  object 
 7   rarity              12373 non-null  object 
 8   large_image_source  12373 non-null  object 
 9   pokemon_image       12373 non-null  object 
 10  caption             12373 non-null  object 
 11  pokemon_intro       12373 non-null  object 
 12  set_name            12373 non-null  object 
dtypes: float64(1), object(12)
memory usage: 1.3+ MB


In [4]:
merged_cards.head()

Unnamed: 0,id,name,supertype,subtypes,hp,types,artist,rarity,large_image_source,pokemon_image,caption,pokemon_intro,set_name
0,base1-1,Alakazam,Pokémon,"[""Stage 2""]",80.0,"[""Psychic""]",Ken Sugimori,Rare Holo,"""https://images.pokemontcg.io/base1/1_hires.png""",https://img.pokemondb.net/artwork/alakazam.jpg,A Stage 2 Pokemon Card of type Psychic with th...,Alakazam is a Psychic type Pokémon introduced ...,Base
1,base1-2,Blastoise,Pokémon,"[""Stage 2""]",100.0,"[""Water""]",Ken Sugimori,Rare Holo,"""https://images.pokemontcg.io/base1/2_hires.png""",https://img.pokemondb.net/artwork/blastoise.jpg,A Stage 2 Pokemon Card of type Water with the ...,Blastoise is a Water type Pokémon introduced i...,Base
2,base1-3,Chansey,Pokémon,"[""Basic""]",120.0,"[""Colorless""]",Ken Sugimori,Rare Holo,"""https://images.pokemontcg.io/base1/3_hires.png""",https://img.pokemondb.net/artwork/chansey.jpg,A Basic Pokemon Card of type Colorless with th...,Chansey is a Normal type Pokémon introduced in...,Base
3,base1-4,Charizard,Pokémon,"[""Stage 2""]",120.0,"[""Fire""]",Mitsuhiro Arita,Rare Holo,"""https://images.pokemontcg.io/base1/4_hires.png""",https://img.pokemondb.net/artwork/charizard.jpg,A Stage 2 Pokemon Card of type Fire with the t...,Charizard is a Fire/Flying type Pokémon introd...,Base
4,base1-5,Clefairy,Pokémon,"[""Basic""]",40.0,"[""Colorless""]",Ken Sugimori,Rare Holo,"""https://images.pokemontcg.io/base1/5_hires.png""",https://img.pokemondb.net/artwork/clefairy.jpg,A Basic Pokemon Card of type Colorless with th...,Clefairy is a Fairy type Pokémon introduced in...,Base


In [5]:
categorical_columns = ['name', 'supertype', 'artist', 'rarity', 'set_name', 'id']

for col in categorical_columns:
    merged_cards[col] = le.fit_transform(merged_cards[col])

def encode_array_column(column):
    merged_cards[column] = merged_cards[column].apply(lambda x: ast.literal_eval(x) if isinstance(x, str) else x)
    
    exploded = merged_cards[column].explode().dropna().unique()
    
    le.fit(exploded)
    def encode_list(x):
        if isinstance(x, list):
            return [le.transform([item])[0] for item in x if isinstance(item, str)]
        return []
    
    merged_cards[column] = merged_cards[column].apply(encode_list)

encode_array_column('subtypes')
encode_array_column('types')

merged_cards['hp'] = scaler.fit_transform(merged_cards[['hp']])

def upscale_image(url, size=(64, 64), retries=10, timeout=10):
    if isinstance(url, str):
        cleaned_url = url.strip('"').strip("'")
        
        for attempt in range(retries):
            try:
                response = requests.get(cleaned_url, timeout=timeout)
                response.raise_for_status()
                img = Image.open(BytesIO(response.content))
                img = img.resize(size)
                return img
            except (requests.exceptions.RequestException, Exception) as e:
                time.sleep(random.uniform(1, 3))
                if attempt == retries - 1:
                    return None
    else:
        return None
        
def upscale_images_with_progress_parallel(df, column_name, size=(64, 64)):
    results = []
    with tqdm(total=len(df), desc=f"Upscaling {column_name}", unit="image") as pbar:
        with concurrent.futures.ThreadPoolExecutor() as executor:
            future_to_url = {executor.submit(upscale_image, url, size): url for url in df[column_name]}
            
            for future in concurrent.futures.as_completed(future_to_url):
                results.append(future.result())
                pbar.update(1)

    return results

merged_cards['large_64x64'] = upscale_images_with_progress_parallel(merged_cards, 'large_image_source', size=(64, 64))
merged_cards['pokemon_64x64'] = upscale_images_with_progress_parallel(merged_cards, 'pokemon_image', size=(64, 64))

def preprocess_text_spacy(text):
    """Preprocess text using spaCy: lowercase, remove punctuation, and lemmatize"""
    text = text.lower()
    text = text.translate(str.maketrans('', '', string.punctuation))
    doc = nlp(text)
    text = ' '.join([token.lemma_ for token in doc])
    return text

def get_distilbert_embeddings(text):
    """Get DistilBERT embeddings for the provided text"""
    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=512)
    inputs = {key: value.to(device) for key, value in inputs.items()}
    with torch.no_grad():
        outputs = model(**inputs)
    embeddings = outputs.last_hidden_state.mean(dim=1).squeeze().cpu().numpy()
    return embeddings

def apply_parallel_thread(df, func, column_name):
    """Helper function to apply a function in parallel with tqdm for progress tracking"""
    results = []
    with tqdm(total=len(df), desc=f"Processing {column_name}", unit="row") as pbar:
        with concurrent.futures.ThreadPoolExecutor() as executor:
            futures = {executor.submit(func, row): row for row in df[column_name]}
            for future in futures:
                results.append(future.result())
                pbar.update(1)
    return results

def apply_parallel_embeddings(df, func, column_name):
    """Helper function to apply DistilBERT embeddings in parallel"""
    results = []
    with tqdm(total=len(df), desc=f"Processing embeddings {column_name}", unit="row") as pbar:
        with concurrent.futures.ThreadPoolExecutor() as executor:
            futures = {executor.submit(func, row): row for row in df[column_name]}
            for future in futures:
                results.append(future.result())
                pbar.update(1)
    return results

for col in ['caption', 'pokemon_intro']:
    merged_cards[col] = apply_parallel_thread(merged_cards, preprocess_text_spacy, col)
    merged_cards[col + '_embeddings'] = apply_parallel_embeddings(merged_cards, get_distilbert_embeddings, col)

final_cards = merged_cards[['id', 'name', 'supertype', 'subtypes', 'hp', 'types', 'artist', 'rarity', 'set_name', 'large_64x64', 'pokemon_64x64', 'caption_embeddings', 'pokemon_intro_embeddings']]
final_cards = final_cards.dropna()

Upscaling large_image_source: 100%|███████████████████████████████████████████| 12373/12373 [02:06<00:00, 98.12image/s]
Upscaling pokemon_image: 100%|████████████████████████████████████████████████| 12373/12373 [04:53<00:00, 42.11image/s]
Processing caption: 100%|██████████████████████████████████████████████████████| 12373/12373 [01:36<00:00, 127.96row/s]
Processing embeddings caption: 100%|███████████████████████████████████████████| 12373/12373 [00:44<00:00, 277.74row/s]
Processing pokemon_intro: 100%|████████████████████████████████████████████████| 12373/12373 [01:24<00:00, 146.61row/s]
Processing embeddings pokemon_intro: 100%|█████████████████████████████████████| 12373/12373 [00:33<00:00, 367.39row/s]


In [6]:
final_cards.head()

Unnamed: 0,id,name,supertype,subtypes,hp,types,artist,rarity,set_name,large_64x64,pokemon_64x64,caption_embeddings,pokemon_intro_embeddings
0,0,22,0,[17],0.16129,[9],85,7,8,<PIL.Image.Image image mode=RGB size=64x64 at ...,<PIL.Image.Image image mode=RGB size=64x64 at ...,"[-0.3966347, -0.18685628, 0.27757892, 0.011523...","[-0.5267142, -0.46775416, 0.29764858, 0.103406..."
1,11,178,0,[17],0.225806,[10],85,7,8,<PIL.Image.Image image mode=RGB size=64x64 at ...,<PIL.Image.Image image mode=RGB size=64x64 at ...,"[-0.35139668, -0.1253799, 0.3185609, 0.1212850...","[-0.36563793, -0.089610875, 0.17557997, 0.1113..."
2,22,285,0,[1],0.290323,[0],85,7,8,<PIL.Image.Image image mode=RGB size=64x64 at ...,<PIL.Image.Image image mode=RGB size=64x64 at ...,"[-0.4593701, -0.15428229, 0.28698304, 0.065975...","[-0.48211092, -0.70858437, 0.112857044, 0.0452..."
3,33,287,0,[17],0.290323,[5],118,7,8,<PIL.Image.Image image mode=RGB size=64x64 at ...,<PIL.Image.Image image mode=RGB size=64x64 at ...,"[-0.4227795, -0.13398156, 0.26598167, 0.058524...","[-0.3931246, -0.11484864, 0.264527, 0.07705603..."
4,44,331,0,[1],0.032258,[0],85,7,8,<PIL.Image.Image image mode=RGB size=64x64 at ...,<PIL.Image.Image image mode=RGB size=64x64 at ...,"[-0.47995847, -0.13535596, 0.3018123, 0.051485...","[-0.32773325, -0.454511, 0.2783921, 0.00911339..."


In [7]:
final_cards.info()

<class 'pandas.core.frame.DataFrame'>
Index: 12369 entries, 0 to 13138
Data columns (total 13 columns):
 #   Column                    Non-Null Count  Dtype  
---  ------                    --------------  -----  
 0   id                        12369 non-null  int64  
 1   name                      12369 non-null  int64  
 2   supertype                 12369 non-null  int64  
 3   subtypes                  12369 non-null  object 
 4   hp                        12369 non-null  float64
 5   types                     12369 non-null  object 
 6   artist                    12369 non-null  int64  
 7   rarity                    12369 non-null  int64  
 8   set_name                  12369 non-null  int64  
 9   large_64x64               12369 non-null  object 
 10  pokemon_64x64             12369 non-null  object 
 11  caption_embeddings        12369 non-null  object 
 12  pokemon_intro_embeddings  12369 non-null  object 
dtypes: float64(1), int64(6), object(6)
memory usage: 1.3+ MB


In [8]:
final_cards['subtypes'] = final_cards['subtypes'].apply(lambda x: [int(i) for i in x])
final_cards['types'] = final_cards['types'].apply(lambda x: [int(i) for i in x])

def image_to_flat_array(image):
    if isinstance(image, Image.Image):
        image = image.convert('RGB')
        image_array = np.array(image)
        flattened_array = image_array.flatten()
        return flattened_array.tolist()
    else:
        raise ValueError("Input should be a PIL Image object")

final_cards['large_64x64'] = final_cards['large_64x64'].apply(image_to_flat_array)
final_cards['pokemon_64x64'] = final_cards['pokemon_64x64'].apply(image_to_flat_array)


In [9]:
final_cards.to_csv("preprocessed_data.csv")

In [10]:
final_cards.head()

Unnamed: 0,id,name,supertype,subtypes,hp,types,artist,rarity,set_name,large_64x64,pokemon_64x64,caption_embeddings,pokemon_intro_embeddings
0,0,22,0,[17],0.16129,[9],85,7,8,"[246, 213, 68, 247, 212, 66, 248, 214, 61, 249...","[255, 255, 255, 255, 255, 255, 255, 255, 255, ...","[-0.3966347, -0.18685628, 0.27757892, 0.011523...","[-0.5267142, -0.46775416, 0.29764858, 0.103406..."
1,11,178,0,[17],0.225806,[10],85,7,8,"[243, 208, 61, 243, 207, 59, 243, 206, 54, 243...","[255, 255, 255, 255, 255, 255, 255, 255, 255, ...","[-0.35139668, -0.1253799, 0.3185609, 0.1212850...","[-0.36563793, -0.089610875, 0.17557997, 0.1113..."
2,22,285,0,[1],0.290323,[0],85,7,8,"[234, 208, 72, 235, 208, 70, 237, 210, 68, 236...","[255, 255, 255, 255, 255, 255, 255, 255, 255, ...","[-0.4593701, -0.15428229, 0.28698304, 0.065975...","[-0.48211092, -0.70858437, 0.112857044, 0.0452..."
3,33,287,0,[17],0.290323,[5],118,7,8,"[243, 212, 67, 243, 212, 64, 244, 212, 59, 245...","[255, 255, 255, 255, 255, 255, 255, 255, 255, ...","[-0.4227795, -0.13398156, 0.26598167, 0.058524...","[-0.3931246, -0.11484864, 0.264527, 0.07705603..."
4,44,331,0,[1],0.032258,[0],85,7,8,"[242, 207, 63, 240, 205, 62, 242, 206, 59, 243...","[255, 255, 255, 255, 255, 255, 255, 255, 255, ...","[-0.47995847, -0.13535596, 0.3018123, 0.051485...","[-0.32773325, -0.454511, 0.2783921, 0.00911339..."
