In [4]:
import torch
import torch.nn as nn
import torch.optim as optim
import pandas as pd
import numpy as np
from sklearn.preprocessing import StandardScaler, OneHotEncoder
from sklearn.compose import ColumnTransformer
from sklearn.pipeline import Pipeline


In [5]:
df = pd.read_csv('adult.csv')

numerical_cols = df.select_dtypes(include=['float64', 'int64']).columns
categorical_cols = df.select_dtypes(include=['object']).columns

# Preprocessing for numerical data
numerical_transformer = StandardScaler()

# Preprocessing for categorical data
categorical_transformer = OneHotEncoder(handle_unknown='ignore')

preprocessor = ColumnTransformer(
    transformers=[
        ('num', numerical_transformer, numerical_cols),
        ('cat', categorical_transformer, categorical_cols)
    ])

# Fit and transform the data
data_preprocessed = preprocessor.fit_transform(df)


In [6]:
class TransformerBlock(nn.Module):
    def __init__(self, embed_dim, num_heads, ff_hidden_dim, dropout=0.1):
        super(TransformerBlock, self).__init__()
        self.attention = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout)
        self.norm1 = nn.LayerNorm(embed_dim)
        self.ffn = nn.Sequential(
            nn.Linear(embed_dim, ff_hidden_dim),
            nn.ReLU(),
            nn.Linear(ff_hidden_dim, embed_dim),
        )
        self.norm2 = nn.LayerNorm(embed_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        attn_output, _ = self.attention(x, x, x)
        x = self.norm1(x + self.dropout(attn_output))
        ffn_output = self.ffn(x)
        x = self.norm2(x + self.dropout(ffn_output))
        return x


In [7]:
class Generator(nn.Module):
    def __init__(self, input_dim, output_dim, embed_dim, num_heads, ff_hidden_dim, num_layers):
        super(Generator, self).__init__()
        self.linear = nn.Linear(input_dim, embed_dim)
        self.transformer_blocks = nn.ModuleList([
            TransformerBlock(embed_dim, num_heads, ff_hidden_dim) for _ in range(num_layers)
        ])
        self.output_layer = nn.Linear(embed_dim, output_dim)

    def forward(self, x):
        x = self.linear(x).unsqueeze(1)  # Add sequence dimension
        for block in self.transformer_blocks:
            x = block(x)
        x = self.output_layer(x.squeeze(1))  # Remove sequence dimension
        return x


In [8]:
class Discriminator(nn.Module):
    def __init__(self, input_dim, embed_dim, num_heads, ff_hidden_dim, num_layers):
        super(Discriminator, self).__init__()
        self.linear = nn.Linear(input_dim, embed_dim)
        self.transformer_blocks = nn.ModuleList([
            TransformerBlock(embed_dim, num_heads, ff_hidden_dim) for _ in range(num_layers)
        ])
        self.output_layer = nn.Linear(embed_dim, 1)

    def forward(self, x):
        x = self.linear(x).unsqueeze(1)  # Add sequence dimension
        for block in self.transformer_blocks:
            x = block(x)
        x = self.output_layer(x.squeeze(1))  # Remove sequence dimension
        return torch.sigmoid(x)


In [10]:
from tqdm import tqdm
# Hyperparameters
input_dim = data_preprocessed.shape[1]
latent_dim = 100
embed_dim = 128
num_heads = 8
ff_hidden_dim = 256
num_layers = 3
batch_size = 64
epochs = 100

# Models
generator = Generator(latent_dim, input_dim, embed_dim, num_heads, ff_hidden_dim, num_layers)
discriminator = Discriminator(input_dim, embed_dim, num_heads, ff_hidden_dim, num_layers)

# Optimizers
optimizer_G = optim.Adam(generator.parameters(), lr=0.0002)
optimizer_D = optim.Adam(discriminator.parameters(), lr=0.0002)

# Loss function
criterion = nn.BCELoss()

# Training loop
for epoch in range(epochs):
    for _ in tqdm(range(data_preprocessed.shape[0] // batch_size)):
        # Train Discriminator
        idx = np.random.randint(0, data_preprocessed.shape[0], batch_size)
        real_data = torch.tensor(data_preprocessed[idx].toarray(), dtype=torch.float32)
        real_labels = torch.ones((batch_size, 1))
        fake_labels = torch.zeros((batch_size, 1))

        latent_space_samples = torch.randn((batch_size, latent_dim))
        generated_data = generator(latent_space_samples)

        real_predictions = discriminator(real_data)
        fake_predictions = discriminator(generated_data.detach())

        d_loss_real = criterion(real_predictions, real_labels)
        d_loss_fake = criterion(fake_predictions, fake_labels)
        d_loss = d_loss_real + d_loss_fake

        optimizer_D.zero_grad()
        d_loss.backward()
        optimizer_D.step()

        # Train Generator
        generated_data = generator(latent_space_samples)
        fake_predictions = discriminator(generated_data)

        g_loss = criterion(fake_predictions, real_labels)

        optimizer_G.zero_grad()
        g_loss.backward()
        optimizer_G.step()

    print(f'Epoch {epoch+1}/{epochs}, D Loss: {d_loss.item()}, G Loss: {g_loss.item()}')


100%|██████████| 508/508 [00:26<00:00, 18.87it/s]


Epoch 1/100, D Loss: 0.002641930477693677, G Loss: 7.150287628173828


100%|██████████| 508/508 [00:26<00:00, 18.92it/s]


Epoch 2/100, D Loss: 0.0013189775636419654, G Loss: 7.834990501403809


100%|██████████| 508/508 [00:27<00:00, 18.46it/s]


Epoch 3/100, D Loss: 0.0015773983905091882, G Loss: 8.198347091674805


100%|██████████| 508/508 [00:27<00:00, 18.27it/s]


Epoch 4/100, D Loss: 0.0005443186964839697, G Loss: 8.549238204956055


100%|██████████| 508/508 [00:27<00:00, 18.17it/s]


Epoch 5/100, D Loss: 0.0005815893528051674, G Loss: 8.61555290222168


100%|██████████| 508/508 [00:27<00:00, 18.36it/s]


Epoch 6/100, D Loss: 0.0010563037358224392, G Loss: 7.343108654022217


100%|██████████| 508/508 [00:27<00:00, 18.69it/s]


Epoch 7/100, D Loss: 0.0007379487506113946, G Loss: 8.441676139831543


100%|██████████| 508/508 [00:27<00:00, 18.52it/s]


Epoch 8/100, D Loss: 0.0003854089882224798, G Loss: 8.654929161071777


100%|██████████| 508/508 [00:27<00:00, 18.55it/s]


Epoch 9/100, D Loss: 0.0003135347506031394, G Loss: 8.899312973022461


100%|██████████| 508/508 [00:27<00:00, 18.57it/s]


Epoch 10/100, D Loss: 0.00026678681024350226, G Loss: 8.949024200439453


100%|██████████| 508/508 [00:27<00:00, 18.71it/s]


Epoch 11/100, D Loss: 0.00020740322361234576, G Loss: 9.417472839355469


100%|██████████| 508/508 [00:27<00:00, 18.62it/s]


Epoch 12/100, D Loss: 0.00017741319607011974, G Loss: 9.675848960876465


100%|██████████| 508/508 [00:27<00:00, 18.45it/s]


Epoch 13/100, D Loss: 0.00013662023411598057, G Loss: 9.887351989746094


100%|██████████| 508/508 [00:27<00:00, 18.57it/s]


Epoch 14/100, D Loss: 0.00010309355275239795, G Loss: 10.094850540161133


100%|██████████| 508/508 [00:27<00:00, 18.59it/s]


Epoch 15/100, D Loss: 8.163384336512536e-05, G Loss: 10.34306526184082


100%|██████████| 508/508 [00:27<00:00, 18.74it/s]


Epoch 16/100, D Loss: 6.654774188064039e-05, G Loss: 10.505105018615723


100%|██████████| 508/508 [00:27<00:00, 18.69it/s]


Epoch 17/100, D Loss: 0.00021635540178976953, G Loss: 9.230762481689453


100%|██████████| 508/508 [00:27<00:00, 18.81it/s]


Epoch 18/100, D Loss: 0.00019116007024422288, G Loss: 9.45492172241211


100%|██████████| 508/508 [00:27<00:00, 18.67it/s]


Epoch 19/100, D Loss: 0.00016594695625826716, G Loss: 9.562856674194336


100%|██████████| 508/508 [00:27<00:00, 18.47it/s]


Epoch 20/100, D Loss: 0.00013978505739942193, G Loss: 9.765186309814453


100%|██████████| 508/508 [00:27<00:00, 18.53it/s]


Epoch 21/100, D Loss: 0.00011858678772114217, G Loss: 9.89698600769043


100%|██████████| 508/508 [00:27<00:00, 18.34it/s]


Epoch 22/100, D Loss: 9.991079423343763e-05, G Loss: 10.05087661743164


100%|██████████| 508/508 [00:27<00:00, 18.61it/s]


Epoch 23/100, D Loss: 0.00010583228140603751, G Loss: 10.049247741699219


100%|██████████| 508/508 [00:27<00:00, 18.25it/s]


Epoch 24/100, D Loss: 0.00012503695324994624, G Loss: 9.88905143737793


100%|██████████| 508/508 [00:27<00:00, 18.15it/s]


Epoch 25/100, D Loss: 0.0001392374251736328, G Loss: 10.088943481445312


100%|██████████| 508/508 [00:27<00:00, 18.63it/s]


Epoch 26/100, D Loss: 0.00013072371075395495, G Loss: 10.25478458404541


100%|██████████| 508/508 [00:27<00:00, 18.80it/s]


Epoch 27/100, D Loss: 9.746754949446768e-05, G Loss: 10.39176082611084


100%|██████████| 508/508 [00:27<00:00, 18.47it/s]


Epoch 28/100, D Loss: 0.00010999138612532988, G Loss: 10.544368743896484


100%|██████████| 508/508 [00:27<00:00, 18.72it/s]


Epoch 29/100, D Loss: 7.069276762194932e-05, G Loss: 10.729772567749023


100%|██████████| 508/508 [00:27<00:00, 18.49it/s]


Epoch 30/100, D Loss: 0.00014066434232518077, G Loss: 10.116379737854004


100%|██████████| 508/508 [00:27<00:00, 18.63it/s]


Epoch 31/100, D Loss: 0.00011005566921085119, G Loss: 10.277688980102539


100%|██████████| 508/508 [00:27<00:00, 18.16it/s]


Epoch 32/100, D Loss: 0.0001688917400315404, G Loss: 10.111069679260254


100%|██████████| 508/508 [00:27<00:00, 18.59it/s]


Epoch 33/100, D Loss: 0.00011282786726951599, G Loss: 10.236989974975586


100%|██████████| 508/508 [00:27<00:00, 18.34it/s]


Epoch 34/100, D Loss: 9.37360746320337e-05, G Loss: 10.329834938049316


100%|██████████| 508/508 [00:26<00:00, 19.00it/s]


Epoch 35/100, D Loss: 8.049399184528738e-05, G Loss: 10.423408508300781


100%|██████████| 508/508 [00:27<00:00, 18.79it/s]


Epoch 36/100, D Loss: 6.92237590556033e-05, G Loss: 10.514349937438965


100%|██████████| 508/508 [00:27<00:00, 18.74it/s]


Epoch 37/100, D Loss: 0.0002542842412367463, G Loss: 9.016712188720703


100%|██████████| 508/508 [00:26<00:00, 18.86it/s]


Epoch 38/100, D Loss: 0.00018958598957397044, G Loss: 9.502943992614746


100%|██████████| 508/508 [00:27<00:00, 18.70it/s]


Epoch 39/100, D Loss: 0.000166910671396181, G Loss: 9.755376815795898


100%|██████████| 508/508 [00:27<00:00, 18.34it/s]


Epoch 40/100, D Loss: 0.0008782775839790702, G Loss: 8.451332092285156


100%|██████████| 508/508 [00:27<00:00, 18.69it/s]


Epoch 41/100, D Loss: 0.0005935446824878454, G Loss: 8.771256446838379


100%|██████████| 508/508 [00:27<00:00, 18.77it/s]


Epoch 42/100, D Loss: 0.0003973627754021436, G Loss: 8.962388038635254


100%|██████████| 508/508 [00:27<00:00, 18.66it/s]


Epoch 43/100, D Loss: 0.0009618434123694897, G Loss: 8.668275833129883


100%|██████████| 508/508 [00:27<00:00, 18.53it/s]


Epoch 44/100, D Loss: 0.000580674852244556, G Loss: 8.924766540527344


100%|██████████| 508/508 [00:26<00:00, 18.86it/s]


Epoch 45/100, D Loss: 0.0003520448808558285, G Loss: 9.111336708068848


100%|██████████| 508/508 [00:27<00:00, 18.67it/s]


Epoch 46/100, D Loss: 0.00033186201471835375, G Loss: 9.304418563842773


100%|██████████| 508/508 [00:27<00:00, 18.74it/s]


Epoch 47/100, D Loss: 0.0002355455217184499, G Loss: 9.510485649108887


100%|██████████| 508/508 [00:27<00:00, 18.42it/s]


Epoch 48/100, D Loss: 0.00018388066382613033, G Loss: 9.71374797821045


100%|██████████| 508/508 [00:27<00:00, 18.29it/s]


Epoch 49/100, D Loss: 0.0001476233301218599, G Loss: 9.929492950439453


100%|██████████| 508/508 [00:27<00:00, 18.72it/s]


Epoch 50/100, D Loss: 0.00011753465514630079, G Loss: 10.147013664245605


100%|██████████| 508/508 [00:27<00:00, 18.30it/s]


Epoch 51/100, D Loss: 9.772906196303666e-05, G Loss: 10.256328582763672


100%|██████████| 508/508 [00:27<00:00, 18.47it/s]


Epoch 52/100, D Loss: 0.0003890537773258984, G Loss: 9.866067886352539


100%|██████████| 508/508 [00:27<00:00, 18.66it/s]


Epoch 53/100, D Loss: 0.00028700937400572, G Loss: 9.553637504577637


100%|██████████| 508/508 [00:27<00:00, 18.40it/s]


Epoch 54/100, D Loss: 0.00020859797950834036, G Loss: 9.751860618591309


100%|██████████| 508/508 [00:27<00:00, 18.77it/s]


Epoch 55/100, D Loss: 0.00017985918384511024, G Loss: 9.873023986816406


100%|██████████| 508/508 [00:27<00:00, 18.68it/s]


Epoch 56/100, D Loss: 0.00014630115765612572, G Loss: 10.003334045410156


100%|██████████| 508/508 [00:27<00:00, 18.65it/s]


Epoch 57/100, D Loss: 0.0001180565741378814, G Loss: 10.120864868164062


100%|██████████| 508/508 [00:27<00:00, 18.60it/s]


Epoch 58/100, D Loss: 0.00011085668666055426, G Loss: 10.267000198364258


100%|██████████| 508/508 [00:27<00:00, 18.50it/s]


Epoch 59/100, D Loss: 8.296796295326203e-05, G Loss: 10.426833152770996


100%|██████████| 508/508 [00:26<00:00, 18.90it/s]


Epoch 60/100, D Loss: 7.12763867340982e-05, G Loss: 10.57806396484375


100%|██████████| 508/508 [00:26<00:00, 19.22it/s]


Epoch 61/100, D Loss: 5.6659009715076536e-05, G Loss: 10.767629623413086


100%|██████████| 508/508 [00:26<00:00, 19.17it/s]


Epoch 62/100, D Loss: 0.00023709048400633037, G Loss: 10.186620712280273


100%|██████████| 508/508 [00:26<00:00, 19.00it/s]


Epoch 63/100, D Loss: 0.00012048958888044581, G Loss: 10.281296730041504


100%|██████████| 508/508 [00:26<00:00, 19.02it/s]


Epoch 64/100, D Loss: 0.00010376227146480232, G Loss: 10.357683181762695


100%|██████████| 508/508 [00:26<00:00, 19.29it/s]


Epoch 65/100, D Loss: 8.896077633835375e-05, G Loss: 10.134119033813477


100%|██████████| 508/508 [00:26<00:00, 18.86it/s]


Epoch 66/100, D Loss: 0.0012326936703175306, G Loss: 10.0010986328125


100%|██████████| 508/508 [00:26<00:00, 18.87it/s]


Epoch 67/100, D Loss: 0.0001452096039429307, G Loss: 10.1367826461792


100%|██████████| 508/508 [00:27<00:00, 18.80it/s]


Epoch 68/100, D Loss: 0.00010579839727142826, G Loss: 10.215027809143066


100%|██████████| 508/508 [00:27<00:00, 18.55it/s]


Epoch 69/100, D Loss: 9.154168947134167e-05, G Loss: 10.356267929077148


100%|██████████| 508/508 [00:27<00:00, 18.79it/s]


Epoch 70/100, D Loss: 7.922885561129078e-05, G Loss: 10.438141822814941


100%|██████████| 508/508 [00:26<00:00, 18.86it/s]


Epoch 71/100, D Loss: 6.730313180014491e-05, G Loss: 10.615350723266602


100%|██████████| 508/508 [00:26<00:00, 18.91it/s]


Epoch 72/100, D Loss: 5.777280603069812e-05, G Loss: 10.755666732788086


100%|██████████| 508/508 [00:26<00:00, 18.91it/s]


Epoch 73/100, D Loss: 4.7870482376310974e-05, G Loss: 10.926957130432129


100%|██████████| 508/508 [00:26<00:00, 19.16it/s]


Epoch 74/100, D Loss: 0.00010926574759650975, G Loss: 10.136486053466797


100%|██████████| 508/508 [00:26<00:00, 19.19it/s]


Epoch 75/100, D Loss: 0.00010281417053192854, G Loss: 10.197491645812988


100%|██████████| 508/508 [00:26<00:00, 19.07it/s]


Epoch 76/100, D Loss: 9.756229701451957e-05, G Loss: 10.233102798461914


100%|██████████| 508/508 [00:26<00:00, 18.87it/s]


Epoch 77/100, D Loss: 8.98124126251787e-05, G Loss: 10.265610694885254


100%|██████████| 508/508 [00:26<00:00, 19.05it/s]


Epoch 78/100, D Loss: 0.00017441748059354722, G Loss: 9.613821983337402


100%|██████████| 508/508 [00:26<00:00, 18.85it/s]


Epoch 79/100, D Loss: 0.00015866299509070814, G Loss: 9.697870254516602


100%|██████████| 508/508 [00:26<00:00, 19.20it/s]


Epoch 80/100, D Loss: 0.00014597430708818138, G Loss: 9.784790992736816


100%|██████████| 508/508 [00:26<00:00, 19.32it/s]


Epoch 81/100, D Loss: 0.00013223590212874115, G Loss: 9.89561653137207


100%|██████████| 508/508 [00:26<00:00, 19.22it/s]


Epoch 82/100, D Loss: 0.00011704268399626017, G Loss: 10.040275573730469


100%|██████████| 508/508 [00:26<00:00, 19.01it/s]


Epoch 83/100, D Loss: 0.0001019856717903167, G Loss: 10.186015129089355


100%|██████████| 508/508 [00:27<00:00, 18.77it/s]


Epoch 84/100, D Loss: 8.691697439644486e-05, G Loss: 10.331685066223145


100%|██████████| 508/508 [00:26<00:00, 19.28it/s]


Epoch 85/100, D Loss: 7.361043390119448e-05, G Loss: 10.520736694335938


100%|██████████| 508/508 [00:26<00:00, 19.16it/s]


Epoch 86/100, D Loss: 6.153564754640684e-05, G Loss: 10.689407348632812


100%|██████████| 508/508 [00:26<00:00, 19.03it/s]


Epoch 87/100, D Loss: 0.00018569425446912646, G Loss: 9.553768157958984


100%|██████████| 508/508 [00:26<00:00, 19.20it/s]


Epoch 88/100, D Loss: 0.00017040976672433317, G Loss: 9.657069206237793


100%|██████████| 508/508 [00:26<00:00, 19.04it/s]


Epoch 89/100, D Loss: 0.00019865116337314248, G Loss: 9.6639404296875


100%|██████████| 508/508 [00:26<00:00, 19.17it/s]


Epoch 90/100, D Loss: 0.00015179063484538347, G Loss: 9.801020622253418


100%|██████████| 508/508 [00:26<00:00, 19.28it/s]


Epoch 91/100, D Loss: 0.00013625805149786174, G Loss: 9.862372398376465


100%|██████████| 508/508 [00:26<00:00, 19.12it/s]


Epoch 92/100, D Loss: 0.00017161561117973179, G Loss: 9.989849090576172


100%|██████████| 508/508 [00:27<00:00, 18.80it/s]


Epoch 93/100, D Loss: 0.0001421076594851911, G Loss: 9.815397262573242


100%|██████████| 508/508 [00:26<00:00, 18.99it/s]


Epoch 94/100, D Loss: 0.00012066504859831184, G Loss: 9.931624412536621


100%|██████████| 508/508 [00:26<00:00, 19.17it/s]


Epoch 95/100, D Loss: 0.00010365669731982052, G Loss: 10.102945327758789


100%|██████████| 508/508 [00:27<00:00, 18.61it/s]


Epoch 96/100, D Loss: 8.907502342481166e-05, G Loss: 10.2883882522583


100%|██████████| 508/508 [00:26<00:00, 19.02it/s]


Epoch 97/100, D Loss: 7.553557225037366e-05, G Loss: 10.463366508483887


100%|██████████| 508/508 [00:26<00:00, 18.88it/s]


Epoch 98/100, D Loss: 0.00012635687016882002, G Loss: 10.019869804382324


100%|██████████| 508/508 [00:27<00:00, 18.73it/s]


Epoch 99/100, D Loss: 0.00011392342275939882, G Loss: 10.128747940063477


100%|██████████| 508/508 [00:26<00:00, 19.16it/s]

Epoch 100/100, D Loss: 0.00010391234536655247, G Loss: 10.192365646362305





In [15]:
import pandas as pd
import numpy as np
from sklearn.preprocessing import StandardScaler, OneHotEncoder
import torch

# Need to redo this bit.. for inverse transformation
df = pd.read_csv('adult.csv')

numerical_cols = df.select_dtypes(include=['float64', 'int64']).columns
categorical_cols = df.select_dtypes(include=['object']).columns

numerical_transformer = StandardScaler()
numerical_transformer.fit(df[numerical_cols])
df[numerical_cols] = numerical_transformer.transform(df[numerical_cols])

# Preprocessing for categorical data
categorical_transformer = OneHotEncoder(handle_unknown='ignore', sparse=False)
encoded_categorical = categorical_transformer.fit_transform(df[categorical_cols])
encoded_categorical_df = pd.DataFrame(encoded_categorical, columns=categorical_transformer.get_feature_names_out(categorical_cols))

# Combine preprocessed numerical and categorical data
data_preprocessed = pd.concat([df[numerical_cols].reset_index(drop=True), encoded_categorical_df.reset_index(drop=True)], axis=1)
data_preprocessed_np = data_preprocessed.values

latent_dim = 100  # This should match the latent_dim used in your generator
latent_space_samples = torch.randn((data_preprocessed_np.shape[0], latent_dim))
synthetic_data = generator(latent_space_samples).detach().numpy()

# Inverse transform numerical features
synthetic_numerical_data = numerical_transformer.inverse_transform(synthetic_data[:, :len(numerical_cols)])



In [16]:
# Get the synthetic categorical data
synthetic_categorical_data = synthetic_data[:, len(numerical_cols):]

# Inverse transform categorical features
inverse_transformed_categorical_data = categorical_transformer.inverse_transform(synthetic_categorical_data)

# Create DataFrames from the inverse transformed data
synthetic_numerical_df = pd.DataFrame(synthetic_numerical_data, columns=numerical_cols)
synthetic_categorical_df = pd.DataFrame(inverse_transformed_categorical_data, columns=categorical_cols)

# Concatenate numerical and categorical data
synthetic_df = pd.concat([synthetic_numerical_df, synthetic_categorical_df], axis=1)


In [18]:
synthetic_df # nonsense 🥲 (age <0)

Unnamed: 0,age,fnlwgt,education.num,capital.gain,capital.loss,hours.per.week,workclass,education,marital.status,occupation,relationship,race,sex,native.country,income
0,-35.414944,-287215.500000,5.968259,-5537.321289,274.682770,-18.079611,Private,12th,Never-married,Other-service,Own-child,Asian-Pac-Islander,Female,United-States,<=50K
1,67.192314,659716.625000,21.465012,13562.478516,-845.944641,24.585440,?,12th,Married-spouse-absent,Prof-specialty,Not-in-family,Amer-Indian-Eskimo,Female,United-States,>50K
2,-37.984989,-151044.859375,7.302138,-7761.911621,381.110291,-25.416937,Private,12th,Never-married,Other-service,Own-child,Amer-Indian-Eskimo,Female,United-States,<=50K
3,5.614293,-216763.046875,3.854319,-2927.934326,1186.068848,19.353485,Private,11th,Married-civ-spouse,Other-service,Husband,White,Male,United-States,>50K
4,-31.970943,-227197.015625,6.109797,-10172.156250,513.503723,-21.532618,Private,12th,Never-married,Other-service,Own-child,Amer-Indian-Eskimo,Female,United-States,<=50K
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
32556,-35.840191,-239210.703125,5.383281,-8209.619141,451.426239,-23.535036,Private,12th,Never-married,Other-service,Own-child,Asian-Pac-Islander,Female,United-States,<=50K
32557,-42.788166,-171864.015625,5.860575,-8411.810547,311.722992,-25.727713,Private,12th,Never-married,Other-service,Own-child,Asian-Pac-Islander,Female,United-States,<=50K
32558,72.623451,801977.250000,21.129896,15454.623047,-842.971191,29.396048,Self-emp-not-inc,12th,Married-spouse-absent,Prof-specialty,Not-in-family,Asian-Pac-Islander,Female,United-States,>50K
32559,76.364120,736179.937500,21.841326,18654.048828,-966.969482,29.968460,?,12th,Divorced,Adm-clerical,Not-in-family,Asian-Pac-Islander,Female,United-States,>50K


In [19]:
# Saving the complete models
generator_path = 'generator_complete.pth'
discriminator_path = 'discriminator_complete.pth'

torch.save(generator, generator_path)
torch.save(discriminator, discriminator_path)

# Saving the model weights
generator_weights_path = 'generator_weights.pth'
discriminator_weights_path = 'discriminator_weights.pth'

torch.save(generator.state_dict(), generator_weights_path)
torch.save(discriminator.state_dict(), discriminator_weights_path)

In [20]:
!ls

adult.csv		    discriminator_weights.pth  generator_weights.pth
discriminator_complete.pth  generator_complete.pth     sample_data


In [None]:
# # Loading the complete models
# generator = torch.load('generator_complete.pth')
# discriminator = torch.load('discriminator_complete.pth')

# # Loading the model weights
# # Redefine the models
# generator = Generator(latent_dim, input_dim, embed_dim, num_heads, ff_hidden_dim, num_layers)
# discriminator = Discriminator(input_dim, embed_dim, num_heads, ff_hidden_dim, num_layers)

# generator.load_state_dict(torch.load('generator_weights.pth'))
# discriminator.load_state_dict(torch.load('discriminator_weights.pth'))