In [1]:
#load libraries
import torch, math, copy
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import torchvision
from torchvision import datasets, transforms
import torch.nn as nn
import pandas as pd
import torch.nn.init as init
import torch.nn.functional as F
from scipy.stats import kde
import matplotlib.pyplot as plt
from tqdm import tqdm
import random

In [2]:
#basic function to tell whether a number is prime or not
def isPrime(n):
    key = 0
    for j in range(2, n):
        if n % j == 0:
            key += 1
    return key == 0

In [3]:
#generating training data
data = []
for x in tqdm(range(2,100000)):
    if isPrime(x):
        data.append([x])

100%|████████████████████████████████████| 99998/99998 [01:54<00:00, 871.27it/s]


In [6]:
#sample from training data
def sample(n):
    s = random.sample(data, n)
    return s

In [12]:
def dataset(n):
    df = torch.tensor([sample(n)]).float()
    return df

In [24]:
#implementation of GAN network
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.network = nn.Sequential(
            nn.Linear(1, 50),
            nn.ReLU(),
            nn.Linear(50, 50),
            nn.ReLU(),
            nn.Linear(50, 1)
        )

    def decode(self, input):
        out = self.network(input)
        return out

    def forward(self, n):
        z = torch.randint(1, 1100, (1,)).float()
        samples = self.decode(z)
        samples = samples.unsqueeze(0)
        for i in range(n-1):
            z = torch.randint(1, 1100, (1,)).float()
            z = self.decode(z)
            z = z.unsqueeze(0)
            samples = torch.cat((samples, z), 0)
        return (samples.unsqueeze(0)).abs().round()

In [25]:
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()

        self.network = nn.Sequential(
            nn.Linear(1, 50),
            nn.ReLU(),
            nn.Linear(50, 50),
            nn.ReLU(),
            nn.Linear(50, 1)
        )

    def forward(self, input):
        out = self.network(input)
        out = nn.Sigmoid()(out)
        return out

In [26]:
generator = Generator()
gopt = torch.optim.Adam(generator.parameters(), lr=5e-4, betas=(0.5, 0.999))
discriminator = Discriminator()
dopt = torch.optim.Adam(discriminator.parameters(), lr=5e-4, betas=(0.5, 0.999))
criterion = torch.nn.BCEWithLogitsLoss()

Experiments and results:

In [30]:
for i in tqdm(range(120000)):
    
    # Generate real data and labels
    s = dataset(128)
    real_labels = torch.ones((1, 128, 1))
    fake_labels = torch.zeros((1, 128, 1))

    # Generate fake data
    fake_data = generator(128)
    
    #Train G
    gopt.zero_grad()
    g_fake_output = discriminator(fake_data)
    g_loss = criterion(g_fake_output, real_labels)
    g_loss.backward()
    gopt.step()
    
    #Train D
    
    dopt.zero_grad()
    d_real_output = discriminator(s)
    d_real_loss = criterion(d_real_output, real_labels)
    d_fake_output = discriminator(fake_data.detach()) 
    d_fake_loss = criterion(d_fake_output, fake_labels)
    d_loss = d_real_loss + d_fake_loss
    d_loss.backward()
    dopt.step()
    

    if i == 119999:
        checkpoint = generator(200)
        error = 0
        for x in checkpoint.view(-1):
            if isPrime(int(x)) == False:
                error += 1
        print('Accuracy of generator: ', 1 - error/200)

100%|██████████████████████████████████| 120000/120000 [16:41<00:00, 119.81it/s]

Accuracy of generator:  0.28500000000000003





In [31]:
torch.save(generator.state_dict(), 'my_model.pth')

In [61]:
# Load the state dictionary
generator.load_state_dict(torch.load('my_model.pth'))

# Set the model to evaluation mode
generator.eval()
checkpoint = generator(200)
error = 0
for x in checkpoint.view(-1):
    if isPrime(int(x)) == False:
        error += 1
print('Accuracy of generator: ', 1 - error/200)

Accuracy of generator:  0.36
