In [25]:
import pandas as pd
import numpy as np
from models.gan import GAN
from sklearn.preprocessing import StandardScaler

In [26]:
# Read the raw Pokemon table from the project's data folder and preview the first rows.
pokemon_csv = "data/Pokemon.csv"
df_pokemon = pd.read_csv(pokemon_csv)
df_pokemon.head()

Unnamed: 0,#,Name,Type 1,Type 2,Total,HP,Attack,Defense,Sp. Atk,Sp. Def,Speed,Generation,Legendary
0,1,Bulbasaur,Grass,Poison,318,45,49,49,65,65,45,1,False
1,2,Ivysaur,Grass,Poison,405,60,62,63,80,80,60,1,False
2,3,Venusaur,Grass,Poison,525,80,82,83,100,100,80,1,False
3,3,VenusaurMega Venusaur,Grass,Poison,625,80,100,123,122,120,80,1,False
4,4,Charmander,Fire,,309,39,52,43,60,50,65,1,False


In [27]:
# Remove the columns that are not fed to the model.
df_pokemon = df_pokemon.drop(columns=['Name', 'Total', '#'])

# Mean and standard deviation of the raw (unscaled) numeric stat columns.
# 'Generation' is numeric but is one-hot encoded later as a category, so it is
# excluded by name instead of relying on it being the last numeric column.
numeric_stats = df_pokemon.select_dtypes(include=[np.number]).drop(columns=['Generation'])
mean_values = numeric_stats.mean().to_list()
std_values = numeric_stats.std().to_list()

[69.25875, 79.00125, 73.8425, 72.82, 71.9025, 68.2775]


In [28]:
# Continuous stat columns that get standardised.
numerical_cols = ['HP', 'Attack', 'Defense', 'Sp. Atk', 'Sp. Def', 'Speed']

# Fit the scaler on the stats and keep it: later cells call
# scaler.inverse_transform to map generated values back.
scaler = StandardScaler()
scaled_stats = scaler.fit_transform(df_pokemon[numerical_cols])
df_pokemon[numerical_cols] = scaled_stats

# Rows without a secondary type get the explicit category 'None',
# which then receives its own dummy column below.
df_pokemon['Type 2'] = df_pokemon['Type 2'].fillna('None')

# One-hot encode the categorical columns.
categorical_cols = ['Type 1', 'Type 2', 'Generation']
df_pokemon = pd.get_dummies(df_pokemon, columns=categorical_cols)

# Turn every boolean column into integers.
bool_cols = df_pokemon.select_dtypes(include=['bool']).columns
df_pokemon[bool_cols] = df_pokemon[bool_cols].astype(int)

# Remember the final column layout so generated rows can be labelled with it.
column_names = df_pokemon.columns
df_pokemon.head()

Unnamed: 0,HP,Attack,Defense,Sp. Atk,Sp. Def,Speed,Legendary,Type 1_Bug,Type 1_Dark,Type 1_Dragon,...,Type 2_Psychic,Type 2_Rock,Type 2_Steel,Type 2_Water,Generation_1,Generation_2,Generation_3,Generation_4,Generation_5,Generation_6
0,-0.950626,-0.924906,-0.797154,-0.23913,-0.248189,-0.801503,0,0,0,0,...,0,0,0,0,1,0,0,0,0,0
1,-0.362822,-0.52413,-0.347917,0.21956,0.291156,-0.285015,0,0,0,0,...,0,0,0,0,1,0,0,0,0,0
2,0.420917,0.092448,0.293849,0.831146,1.010283,0.403635,0,0,0,0,...,0,0,0,0,1,0,0,0,0,0
3,0.420917,0.647369,1.577381,1.503891,1.729409,0.403635,0,0,0,0,...,0,0,0,0,1,0,0,0,0,0
4,-1.185748,-0.832419,-0.989683,-0.392027,-0.787533,-0.112853,0,0,0,0,...,0,0,0,0,1,0,0,0,0,0


Train GAN model

In [29]:
# Build the GAN on the fully encoded frame, create both networks, and train them.
gan = GAN(data=df_pokemon, noise_dim=100, epochs=10, batch_size=32)
generator = gan.create_generator()
discriminator = gan.create_discriminator()
gan_model = gan.compile(generator=generator, discriminator=discriminator)

# The object returned by train() is what the sampling cell below calls .predict() on.
trained_gan = gan.train(generator=generator, discriminator=discriminator, gan=gan_model)


[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 34ms/step
64
64


  super().__init__(activity_regularizer=activity_regularizer, **kwargs)


>1, d1=0.675, d2=1.113
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 25ms/step
64
64
>2, d1=0.677, d2=1.104
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 23ms/step
64
64
>3, d1=0.680, d2=1.092
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 25ms/step
64
64
>4, d1=0.677, d2=1.076
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 15ms/step
64
64
>5, d1=0.676, d2=1.066
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 12ms/step
64
64
>6, d1=0.670, d2=1.055
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 13ms/step
64
64
>7, d1=0.665, d2=1.044
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 13ms/step
64
64
>8, d1=0.664, d2=1.035
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 14ms/step
64
64
>9, d1=0.661, d2=1.026
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 12ms/step
64
64
>10, d1=0.658, d2=1.014


Get predictions from the GAN model

In [30]:
# Draw standard-normal noise vectors and run them through the trained model.
generate_num = 10
latent_dim = 100  # same value as the noise_dim passed to GAN(...) in the training cell
noise = np.random.normal(loc=0, scale=1, size=(generate_num, latent_dim))
predicted_data = trained_gan.predict(noise)

# Label the generated matrix with the column names of the encoded training frame.
predicted_df = pd.DataFrame(data=predicted_data, columns=column_names)
predicted_df.head(10)

[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 36ms/step


Unnamed: 0,HP,Attack,Defense,Sp. Atk,Sp. Def,Speed,Legendary,Type 1_Bug,Type 1_Dark,Type 1_Dragon,...,Type 2_Psychic,Type 2_Rock,Type 2_Steel,Type 2_Water,Generation_1,Generation_2,Generation_3,Generation_4,Generation_5,Generation_6
0,0.406082,0.373249,0.516728,0.621164,0.742968,0.284748,0.605653,0.644141,0.521543,0.451696,...,0.381937,0.481619,0.764559,0.281027,0.353335,0.653641,0.5287,0.309845,0.588003,0.553046
1,0.265173,0.5692,0.367911,0.491966,0.722766,0.365292,0.598924,0.596959,0.556552,0.518574,...,0.310155,0.448783,0.790058,0.190797,0.250109,0.721243,0.564512,0.449874,0.529218,0.499743
2,0.374451,0.393847,0.336887,0.618139,0.782437,0.278524,0.760481,0.502319,0.636988,0.455954,...,0.369975,0.536797,0.644022,0.407053,0.286001,0.530842,0.492748,0.379486,0.634569,0.527479
3,0.36662,0.455545,0.675881,0.576883,0.790718,0.174296,0.592982,0.46634,0.464822,0.499802,...,0.298254,0.442285,0.915263,0.327149,0.329549,0.737293,0.528428,0.335883,0.479136,0.371212
4,0.557625,0.605417,0.229444,0.320429,0.854406,0.449371,0.685158,0.253377,0.614825,0.625935,...,0.27507,0.436369,0.519665,0.251403,0.19764,0.482379,0.574996,0.44471,0.545711,0.401397
5,0.310299,0.43018,0.406046,0.673561,0.749702,0.30115,0.75978,0.544736,0.671721,0.448916,...,0.291993,0.528329,0.655573,0.339409,0.321498,0.602042,0.458117,0.475288,0.611121,0.485113
6,0.450337,0.509126,0.450063,0.474237,0.660819,0.332634,0.58138,0.452122,0.546261,0.512283,...,0.39552,0.47257,0.619938,0.365586,0.313844,0.52908,0.577451,0.405206,0.476163,0.366519
7,0.416506,0.475272,0.42543,0.585823,0.633514,0.38451,0.504382,0.691613,0.517157,0.405063,...,0.417499,0.429642,0.608821,0.310921,0.398813,0.549834,0.387241,0.418495,0.544369,0.495101
8,0.168342,0.441453,0.424031,0.707217,0.663815,0.167093,0.513924,0.857177,0.646477,0.526013,...,0.215836,0.479456,0.756565,0.198961,0.258481,0.677185,0.405516,0.410854,0.54392,0.379122
9,0.337369,0.462737,0.326222,0.491315,0.74619,0.403332,0.574154,0.704448,0.510218,0.486689,...,0.359022,0.411726,0.637788,0.182588,0.297383,0.563493,0.474241,0.329861,0.459862,0.490838


Transform predictions into human-readable output

In [31]:
def make_binary_attributes_readable(attributes: list, df: "pd.DataFrame | None" = None) -> pd.DataFrame:
    """Collapse groups of one-hot style columns into one label column per group.

    For every prefix in ``attributes`` (e.g. ``"Type 1_"``), the columns whose
    name starts with that prefix are compared row by row. The suffix of the
    column holding the largest value becomes the row's label (a string) in a
    new column named after the prefix without its last character
    (``"Type 1"``). The prefixed columns are then removed.

    Args:
        attributes: Column-name prefixes, each ending in a one-character
            separator such as ``"_"``.
        df: Frame to convert. When omitted, the notebook-level ``predicted_df``
            is used, so existing one-argument calls keep working.

    Returns:
        A new DataFrame. The input frame is not modified, so the call can be
        repeated. Prefixes that match no column are skipped.
    """
    source = predicted_df if df is None else df
    # Work on a copy: mutating the caller's frame made a second call crash,
    # because the prefixed columns were already gone.
    readable = source.copy()

    for prefix in attributes:
        group_cols = [col for col in readable.columns if str(col).startswith(prefix)]
        if not group_cols:
            # Nothing to collapse for this prefix (e.g. already converted).
            continue

        # Name of the column with the highest value in each row of the group.
        winning_cols = readable[group_cols].idxmax(axis=1)

        # Strip the prefix to obtain the plain label, e.g. 'Type 1_Fire' -> 'Fire'.
        readable[prefix[:-1]] = winning_cols.str[len(prefix):]

        readable = readable.drop(columns=group_cols)

    return readable

In [32]:
# Prefixes of the one-hot column groups to collapse into a single label column each.
make_readble_cols = ["Generation_", "Type 1_", "Type 2_"]
readable_df = make_binary_attributes_readable(make_readble_cols)

# Turn the generated 'Legendary' score into a flag: values above 0.5 count as legendary.
readable_df['Legendary'] = readable_df['Legendary'] > 0.5

# The numeric stats are deliberately left scaled here: the next cell applies
# scaler.inverse_transform, and doing it in both cells would un-scale the values twice.
readable_df.head(10)

Unnamed: 0,HP,Attack,Defense,Sp. Atk,Sp. Def,Speed,Legendary,Generation,Type 1,Type 2
0,0.406082,0.373249,0.516728,0.621164,0.742968,0.284748,0.605653,2,Fairy,Steel
1,0.265173,0.5692,0.367911,0.491966,0.722766,0.365292,0.598924,2,Fairy,Flying
2,0.374451,0.393847,0.336887,0.618139,0.782437,0.278524,0.760481,5,Rock,
3,0.36662,0.455545,0.675881,0.576883,0.790718,0.174296,0.592982,2,Fairy,Steel
4,0.557625,0.605417,0.229444,0.320429,0.854406,0.449371,0.685158,3,Fairy,Ground
5,0.310299,0.43018,0.406046,0.673561,0.749702,0.30115,0.75978,5,Fighting,Flying
6,0.450337,0.509126,0.450063,0.474237,0.660819,0.332634,0.58138,3,Rock,
7,0.416506,0.475272,0.42543,0.585823,0.633514,0.38451,0.504382,2,Poison,Bug
8,0.168342,0.441453,0.424031,0.707217,0.663815,0.167093,0.513924,2,Bug,
9,0.337369,0.462737,0.326222,0.491315,0.74619,0.403332,0.574154,2,Fairy,Bug


In [34]:
# Map the stat columns back through the fitted scaler.
# This overwrites its own input, so re-running the cell applies the inverse transform again.
rescaled_stats = scaler.inverse_transform(readable_df[numerical_cols])
readable_df[numerical_cols] = rescaled_stats


In [35]:
# Preview the first ten generated rows after the inverse transform in the previous cell.
readable_df.head(10)

Unnamed: 0,HP,Attack,Defense,Sp. Atk,Sp. Def,Speed,Legendary,Generation,Type 1,Type 2
0,79.621445,91.108345,89.945808,93.133186,92.565559,76.547234,True,2,Fairy,Steel
1,76.025627,97.464424,85.308098,88.90818,92.003723,78.886414,True,2,Fairy,Flying
2,78.814262,91.776497,84.34124,93.034286,93.663254,76.366486,True,5,Rock,
3,78.61441,93.777794,94.905655,91.685143,93.893562,73.339447,True,2,Fairy,Steel
4,83.488609,98.639191,80.992882,83.298615,95.664833,81.328278,True,3,Fairy,Ground
5,77.177185,92.955017,86.496506,94.846695,92.752853,77.023582,True,5,Fighting,Flying
6,80.750771,95.515808,87.868248,88.328407,90.280884,77.937943,True,3,Rock,
7,79.887436,94.417671,87.100609,91.977486,89.521484,79.444565,True,2,Poison,Bug
8,73.554619,93.320702,87.057014,95.947304,90.364212,73.130257,True,2,Bug,
9,77.867966,94.011101,84.008873,88.886909,92.655182,79.991188,True,2,Fairy,Bug


Evaluate discriminator accuracy

In [36]:
# NOTE(review): the output recorded for this cell is
# "AttributeError: 'GAN' object has no attribute 'evaluate_discriminator'",
# so the GAN class imported from models.gan does not appear to provide this method yet.
# TODO: implement it on GAN (or evaluate the discriminator another way) so this cell runs.
gan.evaluate_discriminator(generator, discriminator)

AttributeError: 'GAN' object has no attribute 'evaluate_discriminator'

Revert the values to a human-readable form

In [None]:
# TODO: Make data human-readable
