In [78]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import os
import re
from PIL import Image
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from sklearn.preprocessing import LabelEncoder
import clip
from tqdm import tqdm

In [5]:
# paths 
DATA_PATH_HEARTHSTONE = '../dataset/Hearthstone-Minion-race/'

In [6]:
df = pd.read_csv(DATA_PATH_HEARTHSTONE + "/train.csv")
df_test = pd.read_csv(DATA_PATH_HEARTHSTONE + "/test.csv")
df_dev = pd.read_csv(DATA_PATH_HEARTHSTONE + '/dev.csv')

In [95]:
df.head()

Unnamed: 0,cardClass,health,id,name,set,attack,cost,rarity,artist,collectible,text,mechanics,race,Image Path,health_text,attack_text,cost_text,collectible_text,combined_text
0,ROGUE,13,Story_06_Tethys,Fleet Admiral Tethys,DARKMOON_FAIRE,1,4,FREE,,,at the end of your turn deal 1 damage to all e...,['TRIGGER_VISUAL'],NONE_race,train_images/Story_06_Tethys.jpg,health: 13,attack: 1,cost: 4,collectible: no,rogue health: 13 darkmoon_faire attack: 1 cost...
1,NEUTRAL,5,EX1_016,Sylvanas Windrunner,EXPERT1,5,6,LEGENDARY,Glenn Rane,1.0,deathrattle take control of a random enemy minion,['DEATHRATTLE'],NONE_race,train_images/EX1_016.jpg,health: 5,attack: 5,cost: 6,collectible: yes,neutral health: 5 expert1 attack: 5 cost: 6 le...
2,NEUTRAL,15,Story_10_IcecrownObelisk,Icecrown Obelisk,STORMWIND,0,3,FREE,,,deathrattle gain control of this minion,['DEATHRATTLE'],NONE_race,train_images/Story_10_IcecrownObelisk.jpg,health: 15,attack: 0,cost: 3,collectible: no,neutral health: 15 stormwind attack: 0 cost: 3...
3,DRUID,5,CORE_CS3_012,Nordrassil Druid,PLACEHOLDER_202204,3,4,RARE,Dave Greco,1.0,battlecry the next spell you cast this turn co...,['BATTLECRY'],NONE_race,train_images/CORE_CS3_012.jpg,health: 5,attack: 3,cost: 4,collectible: yes,druid health: 5 placeholder_202204 attack: 3 c...
4,MAGE,7,BOM_09_Dawngrasp_008t,Dawngrasp,ALTERAC_VALLEY,1,1,FREE,,,freeze any character damaged by this minion re...,['FREEZE'],NONE_race,train_images/BOM_09_Dawngrasp_008t.jpg,health: 7,attack: 1,cost: 1,collectible: no,mage health: 7 alterac_valley attack: 1 cost: ...


In [None]:
df.info()

In [None]:
df.shape

In [None]:
df.isnull().sum()

In [None]:
df['cardClass'].unique()

In [None]:
df['rarity'].unique()

In [None]:
df['race'].unique()

In [None]:
df[df['collectible'].isna()]

In [None]:
df[df['text'].isna()]

In [None]:
df.head()

In [None]:
df['cardClass'].value_counts()

In [None]:
# Plot distribution of card classes
fig = plt.figure(figsize=(12, 5))
df['cardClass'].value_counts().plot(kind='bar')
plt.title('Distribution of Card Classes')
plt.xlabel('Card Class')
plt.ylabel('Count')
plt.show()

In [None]:
grouped_counts = df.groupby(['cardClass', 'rarity']).size().unstack(fill_value=0)

grouped_counts.plot(kind='bar', stacked=True, figsize=(12, 6))

plt.xlabel('Card Class')
plt.ylabel('Count')
plt.title('Distribution of Rarities across Card Classes')

plt.legend(title='Rarity')
plt.grid(True)
plt.show()

In [None]:
grouped_counts = df.groupby(['cardClass', 'race']).size().unstack(fill_value=0)

grouped_counts.plot(kind='bar', stacked=True, figsize=(12, 6))

plt.xlabel('Card Class')
plt.ylabel('Count')
plt.title('Distribution of Rarities across Card Classes')

plt.legend(title='Rarity')
plt.grid(True)
plt.show()

In [None]:
df.head()

In [None]:
df['text'][1]

In [9]:
def get_textual(column, value):
    return column + ": " + str(value)

In [10]:
def get_collectible_text(value):
    if pd.isna(value):
        return "collectible: no"
    return "collectible: yes"

In [11]:
def preprocess_sentence(sentence):
    sentence = re.sub(r'<[^>]+>', '', sentence)
    sentence = re.sub(r'[^\w\s]', '', sentence)
    sentence = sentence.replace('\n', ' ')
    sentence = sentence.lower()
    sentence = ' '.join(sentence.split())
    return sentence

In [12]:
def preprocess_df(df):
    df['cardClass'] = df['cardClass'].fillna('')
    df['health_text'] = df.apply(lambda x: get_textual('health', x['health']), axis=1)
    df['attack_text'] = df.apply(lambda x: get_textual('attack', x['attack']), axis=1)
    df['cost_text'] = df.apply(lambda x: get_textual('cost', x['cost']), axis=1)
    df['collectible_text'] = df['collectible'].apply(get_collectible_text)
    df['text'] = df['text'].fillna('')
    df['text'] = df['text'].apply(preprocess_sentence)
    df['combined_text'] = df['cardClass'].str.lower() + ' ' + df['health_text'] + ' ' + df['set'].str.lower() + ' ' + df['attack_text'] + ' ' + df['cost_text'] + ' ' + df['rarity'].str.lower() + ' ' + df['collectible_text'] + ' ' + df['text']
    return df

In [13]:
df = preprocess_df(df.copy())
df.head()

Unnamed: 0,cardClass,health,id,name,set,attack,cost,rarity,artist,collectible,text,mechanics,race,Image Path,health_text,attack_text,cost_text,collectible_text,combined_text
0,ROGUE,13,Story_06_Tethys,Fleet Admiral Tethys,DARKMOON_FAIRE,1,4,FREE,,,at the end of your turn deal 1 damage to all e...,['TRIGGER_VISUAL'],NONE_race,train_images/Story_06_Tethys.jpg,health: 13,attack: 1,cost: 4,collectible: no,rogue health: 13 darkmoon_faire attack: 1 cost...
1,NEUTRAL,5,EX1_016,Sylvanas Windrunner,EXPERT1,5,6,LEGENDARY,Glenn Rane,1.0,deathrattle take control of a random enemy minion,['DEATHRATTLE'],NONE_race,train_images/EX1_016.jpg,health: 5,attack: 5,cost: 6,collectible: yes,neutral health: 5 expert1 attack: 5 cost: 6 le...
2,NEUTRAL,15,Story_10_IcecrownObelisk,Icecrown Obelisk,STORMWIND,0,3,FREE,,,deathrattle gain control of this minion,['DEATHRATTLE'],NONE_race,train_images/Story_10_IcecrownObelisk.jpg,health: 15,attack: 0,cost: 3,collectible: no,neutral health: 15 stormwind attack: 0 cost: 3...
3,DRUID,5,CORE_CS3_012,Nordrassil Druid,PLACEHOLDER_202204,3,4,RARE,Dave Greco,1.0,battlecry the next spell you cast this turn co...,['BATTLECRY'],NONE_race,train_images/CORE_CS3_012.jpg,health: 5,attack: 3,cost: 4,collectible: yes,druid health: 5 placeholder_202204 attack: 3 c...
4,MAGE,7,BOM_09_Dawngrasp_008t,Dawngrasp,ALTERAC_VALLEY,1,1,FREE,,,freeze any character damaged by this minion re...,['FREEZE'],NONE_race,train_images/BOM_09_Dawngrasp_008t.jpg,health: 7,attack: 1,cost: 1,collectible: no,mage health: 7 alterac_valley attack: 1 cost: ...


In [14]:
def encode_labels(labels, label_encoder):
    labels = label_encoder.fit_transform(labels)
    one_hot_labels = np.zeros((len(labels), 15))
    one_hot_labels[np.arange(len(labels)), labels] = 1
    return one_hot_labels

In [15]:
label_encoder = LabelEncoder()
labels = df['race']
train_labels = encode_labels(labels, label_encoder)

In [16]:
class_mapping = dict(zip(label_encoder.classes_, label_encoder.transform(label_encoder.classes_)))
class_mapping

{'BEAST': 0,
 'DEMON': 1,
 'DRAGON': 2,
 'ELEMENTAL': 3,
 'HUMAN': 4,
 'MECHANICAL': 5,
 'MURLOC': 6,
 'NAGA': 7,
 'NONE_race': 8,
 'OLDGOD': 9,
 'ORC': 10,
 'PIRATE': 11,
 'QUILBOAR': 12,
 'TAUREN': 13,
 'TOTEM': 14}

In [17]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model, preprocess = clip.load("ViT-B/32", device=device)

In [19]:
def combine_image_text(texts, images, clip_model, clip_preprocess):
    def segment_text_func(text, max_segment_length=100):
        # Split text into segments of maximum length
        segments = []
        for i in range(0, len(text), max_segment_length):
            segment = text[i:i + max_segment_length]
            segments.append(segment)
        return segments
    
    image_text = []
    for idx in range(len(texts)):
        text = texts[idx]
        image = images[idx]
        
        # Process text in segments
        segments = segment_text_func(text)
        
        # Encode each segment separately
        segment_features = []
        for segment in segments:
            segment_text = clip.tokenize(segment).to(device)
            with torch.no_grad():
                segment_features.append(clip_model.encode_text(segment_text))
        
        # Aggregate segment representations
        aggregated_text_features = aggregate_features(segment_features)
        
        # Preprocess image
        image = clip_preprocess(Image.open(image)).unsqueeze(0).to(device)
        
        # Encode image
        with torch.no_grad():
            image_features = clip_model.encode_image(image)
        
        # Combine text and image features
        combined_features = torch.cat((aggregated_text_features, image_features), 1)
        image_text.append(combined_features)
    return image_text

def aggregate_features(features):
    # Aggregate segment features (e.g., by averaging)
    return torch.mean(torch.stack(features), dim=0)

In [20]:
texts = list(df['combined_text'])
images = [DATA_PATH_HEARTHSTONE + img for img in list(df['Image Path'])]
data = combine_image_text(texts, images, model, preprocess)

  attn_output = scaled_dot_product_attention(q, k, v, attn_mask, dropout_p, is_causal)


In [29]:
data[0]

tensor([[ 0.0847,  0.2734, -0.0474,  ...,  0.1588, -0.1272,  0.2756]],
       device='cuda:0', dtype=torch.float16)

In [21]:
df_dev = preprocess_df(df_dev.copy())
dev_labels = encode_labels(df_dev['race'], label_encoder)

In [22]:
dev_texts = list(df_dev['combined_text'])
dev_images = [DATA_PATH_HEARTHSTONE + img for img in list(df_dev['Image Path'])]
data_dev = combine_image_text(dev_texts, dev_images, model, preprocess)

In [72]:
class CustomDataset(Dataset):
    def __init__(self, features, labels):
        self.features = features
        self.labels = labels

    def __len__(self):
        return len(self.features)

    def __getitem__(self, idx):
        feature = self.features[idx].clone().detach().to(torch.float32)
        label = torch.tensor(self.labels[idx], dtype=torch.float32)
        return feature, label

In [73]:
class Model(nn.Module):
    def __init__(self, input_size, hidden_size, num_classes):
        super(Model, self).__init__()
        self.fc_1 = nn.Linear(input_size, hidden_size)
        self.fc_2 = nn.Linear(hidden_size, hidden_size)
        self.fc_3 = nn.Linear(hidden_size, num_classes)
        self.relu = nn.ReLU()
        
    def forward(self, x):
        x = x.view(x.size(0), -1) 
        out = self.relu(self.fc_1(x))
        out = self.relu(self.fc_2(out))
        out = self.fc_3(out)
        return out

In [74]:
input_size = data[0].size(1)
hidden_size = 512
num_classes = len(label_encoder.classes_)
num_epochs = 15
lr = 0.001
batch_size = 8

lin_model = Model(input_size, hidden_size, num_classes).to(device)
optimizer = optim.Adam(lin_model.parameters(), lr=lr)
criterion = nn.BCEWithLogitsLoss()

In [75]:
train_dataset = CustomDataset(data, train_labels)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
dev_dataset = CustomDataset(data_dev, dev_labels)
dev_loader = DataLoader(dev_dataset, batch_size=batch_size, shuffle=True)

In [84]:
def train():
    for epoch in range(num_epochs):
        total_loss = 0.0
        for feature, label in train_loader:
            lin_model.train()
            feature = feature.to(device)
            label = label.to(device)
            output = lin_model(feature)
            loss = criterion(output, label)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            
            lin_model.eval()
            with torch.no_grad():
                val_loss = 0.0
                for dev_feature, dev_labels in dev_loader:
                    dev_feature = dev_feature.to(device)
                    dev_labels = dev_labels.to(device)
                    output = lin_model(dev_feature)
                    dev_loss = criterion(output, dev_labels)
                    val_loss += dev_loss.item()
        train_loss = total_loss / len(train_loader)
        val_loss /= len(dev_loader)

        print(f"Epoch [{epoch+1}/{num_epochs}]")
        print(f"  Train Loss: {train_loss:.4f}")
        print(f"  Validation Loss: {val_loss:.4f}")

In [85]:
train()

100%|████████████████████████████████████████████████████████████████████████████████| 675/675 [00:46<00:00, 14.61it/s]
  0%|                                                                                  | 1/675 [00:00<01:16,  8.80it/s]

Epoch [1/15]
  Train Loss: 0.0696
  Validation Loss: 0.0838


100%|████████████████████████████████████████████████████████████████████████████████| 675/675 [01:12<00:00,  9.29it/s]
  0%|                                                                                  | 1/675 [00:00<01:07,  9.96it/s]

Epoch [2/15]
  Train Loss: 0.0537
  Validation Loss: 0.0755


100%|████████████████████████████████████████████████████████████████████████████████| 675/675 [01:08<00:00,  9.83it/s]
  0%|                                                                                  | 1/675 [00:00<01:09,  9.67it/s]

Epoch [3/15]
  Train Loss: 0.0436
  Validation Loss: 0.0940


100%|████████████████████████████████████████████████████████████████████████████████| 675/675 [01:08<00:00,  9.87it/s]
  0%|                                                                                  | 1/675 [00:00<01:14,  9.07it/s]

Epoch [4/15]
  Train Loss: 0.0350
  Validation Loss: 0.0823


100%|████████████████████████████████████████████████████████████████████████████████| 675/675 [01:19<00:00,  8.44it/s]
  0%|                                                                                  | 1/675 [00:00<01:23,  8.05it/s]

Epoch [5/15]
  Train Loss: 0.0269
  Validation Loss: 0.0913


100%|████████████████████████████████████████████████████████████████████████████████| 675/675 [01:23<00:00,  8.08it/s]
  0%|                                                                                  | 1/675 [00:00<01:20,  8.36it/s]

Epoch [6/15]
  Train Loss: 0.0207
  Validation Loss: 0.1163


100%|████████████████████████████████████████████████████████████████████████████████| 675/675 [01:23<00:00,  8.09it/s]
  0%|                                                                                  | 1/675 [00:00<01:16,  8.75it/s]

Epoch [7/15]
  Train Loss: 0.0184
  Validation Loss: 0.1003


100%|████████████████████████████████████████████████████████████████████████████████| 675/675 [01:21<00:00,  8.26it/s]
  0%|                                                                                  | 1/675 [00:00<01:26,  7.77it/s]

Epoch [8/15]
  Train Loss: 0.0136
  Validation Loss: 0.1146


100%|████████████████████████████████████████████████████████████████████████████████| 675/675 [01:24<00:00,  7.98it/s]
  0%|                                                                                  | 1/675 [00:00<01:21,  8.29it/s]

Epoch [9/15]
  Train Loss: 0.0112
  Validation Loss: 0.1498


100%|████████████████████████████████████████████████████████████████████████████████| 675/675 [01:22<00:00,  8.16it/s]
  0%|                                                                                  | 1/675 [00:00<01:25,  7.88it/s]

Epoch [10/15]
  Train Loss: 0.0096
  Validation Loss: 0.1060


100%|████████████████████████████████████████████████████████████████████████████████| 675/675 [01:22<00:00,  8.17it/s]
  0%|                                                                                  | 1/675 [00:00<01:22,  8.17it/s]

Epoch [11/15]
  Train Loss: 0.0101
  Validation Loss: 0.1059


 95%|████████████████████████████████████████████████████████████████████████████▎   | 644/675 [01:18<00:03,  8.21it/s]


KeyboardInterrupt: 

In [58]:
df_test = preprocess_df(df_test.copy())
df_test.head()

Unnamed: 0,cardClass,health,id,name,set,attack,cost,rarity,artist,collectible,text,mechanics,race,Image Path,health_text,attack_text,cost_text,collectible_text,combined_text
0,ROGUE,3,EX1_191,Plaguebringer,LEGACY,3,4,FREE,J. Axer,1.0,battlecry give a friendly minion poisonous,['BATTLECRY'],NONE_race,test_images/EX1_191.jpg,health: 3,attack: 3,cost: 4,collectible: yes,rogue health: 3 legacy attack: 3 cost: 4 free ...
1,PRIEST,3,REV_246,Mysterious Visitor,REVENDRETH,2,2,EPIC,Arthur Bozonnet,1.0,battlecry reduce the cost of cards copied from...,['BATTLECRY'],NONE_race,test_images/REV_246.jpg,health: 3,attack: 2,cost: 2,collectible: yes,priest health: 3 revendreth attack: 2 cost: 2 ...
2,PRIEST,1,CRED_98,Giovanni Scarpati,CREDITS,3,7,LEGENDARY,,,xdivine shield divine shield divine shield div...,,NONE_race,test_images/CRED_98.jpg,health: 1,attack: 3,cost: 7,collectible: no,priest health: 1 credits attack: 3 cost: 7 leg...
3,NEUTRAL,3,Prologue_UnstableEnergy3,Fel Unstable Energy,BLACK_TEMPLE,5,3,FREE,,,cant attack deathrattle deal 2 damage to all m...,"['CANT_ATTACK', 'DEATHRATTLE']",NONE_race,test_images/Prologue_UnstableEnergy3.jpg,health: 3,attack: 5,cost: 3,collectible: no,neutral health: 3 black_temple attack: 5 cost:...
4,PALADIN,5,TRLA_138,Shirvallah's Grace,TROLL,0,0,FREE,Danny Dai,,xshrine after you cast a spell on a friendly m...,"['TRIGGER_VISUAL', 'InvisibleDeathrattle']",NONE_race,test_images/TRLA_138.jpg,health: 5,attack: 0,cost: 0,collectible: no,paladin health: 5 troll attack: 0 cost: 0 free...


In [61]:
test_labels = df_test['race']
test_labels = encode_labels(test_labels, label_encoder)

In [63]:
texts = list(df_test['combined_text'])
images = [DATA_PATH_HEARTHSTONE + img for img in list(df_test['Image Path'])]
test_data = combine_image_text(texts, images, model, preprocess)

In [88]:
test_dataset = CustomDataset(test_data, test_labels)
test_loader = DataLoader(test_dataset, batch_size=batch_size)

In [96]:
def evaluate(model, test_loader):
    total_accuracy = 0.0
    with torch.no_grad():
        for feature, label in test_loader:
            print('feature:', feature)
            feature = feature.to(device)
            output = model(feature).cpu().numpy()
            print('output:', output)
            pred = np.array([np.argmax(out) for out in output])
            label = label.cpu().numpy()
            print('label: ', label)
            label = np.array([np.argmax(l) for l in label])
            accuracy = np.sum(pred == label)
            accuracy /= len(pred)
            total_accuracy += accuracy
            break
    print(f"Accuracy: {total_accuracy/len(test_loader):.3f}")

In [107]:
def evaluate_accuracy(model, test_loader):
    total_accuracy = 0.0
    total_samples = 0
    with torch.no_grad():
        for feature, label in test_loader:
            feature = feature.to(device)
            label = label.to(device)
            output = model(feature)
            probabilities = torch.softmax(output, dim=1)
            _, predicted = torch.max(probabilities, 1)
            correct = (predicted == torch.argmax(label, dim=1)).sum().item()
            total_accuracy += correct
            total_samples += label.size(0)
    accuracy = total_accuracy / total_samples
    print(f"Accuracy: {accuracy:.3f}")

In [114]:
import torch.nn.functional as F

def evaluate_log_loss(model, test_loader):
    total_loss = 0.0
    with torch.no_grad():
        for feature, label in test_loader:
            feature = feature.to(device)
            label = label.to(device)
            output = model(feature)
            
            loss = F.binary_cross_entropy_with_logits(output, label)
            total_loss += loss.item()
    
    average_loss = total_loss / len(test_loader)
    
    print(f"Log Loss: {average_loss:.3f}")

In [115]:
evaluate_accuracy(lin_model, test_loader)

Accuracy: 0.835


In [116]:
evaluate_log_loss(lin_model, test_loader)

Log Loss: 0.106


In [None]:
def combine_image_text(texts, images, labels, clip_model, clip_preprocess):
    image_text = []
    label_list = []
    for idx in range(len(df)):
        text = texts[idx]
        image = IMAGE_PATH + images[idx]
        print(text)
        text = clip.tokenize(text).to(device)
        image = preprocess(Image.open(image)).unsqueeze(0).to(device)

        with torch.no_grad():
            text_features = model.encode_text(text)
            image_features = model.encode_image(image)
        combined_features = torch.cat((text_features), image_features), 1))
        image_text.append(combined_features)
    return image_text