In [195]:
from pathlib import Path

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim

In [196]:
def current(B, arrJ, arrC, y):
    """Total supercurrent through all junctions for one field value and phase.

    Junction ``n`` spans ``[arrJ[2n], arrJ[2n+1]]``. Its contribution is
    ``arrC[n]`` times the average of ``sin(y + 2*pi*B*x)`` sampled at the left
    edge of each of ``numOfSegments`` equal-width segments across the junction
    (a left Riemann sum of the phase-dependent current).

    Parameters
    ----------
    B : float
        Magnetic field; enters the local phase as ``2*pi*B*x``.
    arrJ : sequence of float
        Junction edge positions interleaved as [start0, end0, start1, end1, ...].
        Plain Python floats, NumPy scalars and 0-d tensors are all accepted.
        A trailing unpaired entry is ignored.
    arrC : sequence of float
        Critical current associated with each junction (one per edge pair).
    y : float
        Phase difference at the left end of the whole circuit.

    Returns
    -------
    float
        Sum of the currents through every junction (0 when there are none).
    """
    curr = 0.0
    k = 2 * np.pi * B  # phase gradient along x, constant for every junction
    for n in range(len(arrJ) // 2):
        # float() (rather than .item()) also works for plain Python numbers.
        start = float(arrJ[2 * n])
        width = float(arrJ[2 * n + 1]) - start
        # Narrow junctions still get 5 samples; wider ones ~1 sample per 0.01 of width.
        numOfSegments = 5 if width < 0.05 else int(100 * width)
        positions = start + np.arange(numOfSegments) * (width / numOfSegments)
        curr += arrC[n] * np.mean(np.sin(y + k * positions))
    return curr

In [197]:
def maxCurrent(B, arrayJ, arrayC):
    """Largest total current at field ``B`` over the free phase gamma.

    gamma (the gauge-invariant phase at the left end) is scanned over 150
    evenly spaced values in [0, 2*pi], both endpoints included, and the
    maximum of ``current(B, arrayJ, arrayC, gamma)`` over that grid is returned.
    """
    phase_grid = np.linspace(0, 2 * np.pi, 150)
    return max(current(B, arrayJ, arrayC, gamma) for gamma in phase_grid)

def criticalCurrent(density, arrJ):
    """Width and critical current of every junction.

    ``arrJ`` holds junction edges interleaved as [start0, end0, start1, end1, ...];
    a trailing unpaired entry is ignored. Junction ``k`` has width
    ``arrJ[2k+1] - arrJ[2k]`` and critical current ``width * density[k]``.

    Returns
    -------
    (list, list)
        ``(criticalCurrents, junctionWidths)``, one entry per junction.
    """
    junctionWidths = [arrJ[2 * k + 1] - arrJ[2 * k] for k in range(len(arrJ) // 2)]
    criticalCurrents = [width * density[k] for k, width in enumerate(junctionWidths)]
    return criticalCurrents, junctionWidths

In [198]:
def enforce_constraints(array):
    """Turn a 1-D tensor into an ascending layout that starts at 0 and ends at 1.

    Values are clamped into [0, 1], the first and last entries are pinned to
    0 and 1 respectively, and the values are returned sorted ascending. The
    caller's tensor is not modified: clamping produces a new tensor, and only
    that copy is written to.
    """
    bounded = array.clamp(min=0, max=1)
    bounded[0], bounded[-1] = 0, 1  # pin the endpoints before sorting
    return bounded.sort().values

In [199]:
class Generator(nn.Module):
    """Two-layer perceptron: Linear -> ReLU -> Linear.

    Acts on the last dimension of its input, mapping ``input_size`` features to
    ``output_size`` features. No activation is applied to the output layer.
    """

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        # Layers are created fc1 then fc2 so parameter initialisation draws from
        # the RNG in the same order and state_dict keys stay fc1.* / fc2.*.
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, output_size)
        self.relu = nn.ReLU()

    def forward(self, x):
        hidden = self.relu(self.fc1(x))
        return self.fc2(hidden)

In [206]:
def cost_function(gen_data_tensor, model_data_tensor):
    """Sum of squared element-wise differences between two tensors (SSE loss).

    Returns a 0-d tensor that stays attached to the autograd graph of either
    input. Shapes are combined with ordinary torch broadcasting.
    """
    return ((gen_data_tensor - model_data_tensor) ** 2).sum()

def calculate_critical_current(arrJ, mag_field=None):
    """Normalized critical-current pattern ``Ic(B) / sum(Ic)`` for a junction layout.

    Every junction is given unit critical-current density, so each junction's
    critical current equals its width (via ``criticalCurrent``). For each field
    value, ``maxCurrent`` gives the peak current over the free phase, which is
    then divided by the total critical current.

    Parameters
    ----------
    arrJ : array-like
        Junction edges interleaved as [start0, end0, start1, end1, ...].
    mag_field : array-like, optional
        Field values at which to evaluate the pattern. Defaults to
        ``np.linspace(-10, 10, 10000)``, the grid previously hard-coded here.

    Returns
    -------
    list
        One normalized value per entry of ``mag_field``. Values are non-finite
        if the total junction width is zero.
    """
    if mag_field is None:
        mag_field = np.linspace(-10, 10, 10000)
    criticalCurrents, _ = criticalCurrent(np.ones(len(arrJ) // 2), arrJ)
    # The normalization does not depend on B, so compute it once, not per field value.
    totalCurrent = np.sum(criticalCurrents)
    return [maxCurrent(B, arrJ, criticalCurrents) / totalCurrent for B in mag_field]

In [201]:
# Build the path with pathlib so it resolves on every OS; a backslash-separated
# string only names nested folders on Windows. The directory names are kept
# exactly as they were — they must match the folders on disk.
DATA_PATH = Path("experimental data") / "artifical data" / "remodified 5.0uV.csv"
experimental_model = pd.read_csv(DATA_PATH)
experimental_data = experimental_model["I_c"].to_numpy()  # the I_c column as a 1-D array

# Generator dimensions: noise vector in -> hidden layer -> flat array out.
input_size = 200
hidden_size = 200
output_size = 200
generator = Generator(input_size, hidden_size, output_size)

In [202]:
# Adam over every generator parameter; 1e-3 equals torch's default Adam step size.
optimizer = optim.Adam(params=generator.parameters(), lr=1e-3)

In [None]:
num_epochs = 1000
mag_field = torch.linspace(-10, 10, 10000)  # same field grid calculate_critical_current uses


def _fraunhofer_pattern(edges, field):
    """Differentiable torch counterpart of ``calculate_critical_current``.

    Computes the normalized critical-current pattern ``Ic(B) / sum(widths)`` for
    junctions spanning ``[edges[2n], edges[2n+1]]`` with unit critical-current
    density. Everything stays in torch, so gradients flow back to whatever
    produced ``edges``.

    With unit density, a junction of width ``w`` centred at ``c`` carries
    ``integral of sin(gamma + 2*pi*B*x) dx = Im(exp(i*gamma) * w * sinc(B*w) * exp(2j*pi*B*c))``.
    Here ``torch.sinc`` is the normalized sinc, ``sin(pi*t)/(pi*t)``. Summing over
    junctions gives ``Im(exp(i*gamma) * F(B))``, whose maximum over gamma is ``|F(B)|``.
    This is the continuum limit of the Riemann sum in ``current`` combined with
    the 150-point phase scan in ``maxCurrent``. Values therefore agree with the
    NumPy pipeline up to that discretization error.

    Parameters
    ----------
    edges : torch.Tensor
        1-D tensor of junction edges; a trailing unpaired entry is ignored.
    field : torch.Tensor
        1-D tensor of magnetic-field values.

    Returns
    -------
    torch.Tensor
        1-D tensor, one normalized critical current per field value
        (non-finite if every junction has zero width).
    """
    num_junctions = edges.shape[0] // 2
    starts = edges[0:2 * num_junctions:2]
    ends = edges[1:2 * num_junctions:2]
    widths = ends - starts
    centres = 0.5 * (starts + ends)
    b = field.unsqueeze(1)  # (num_fields, 1) so it broadcasts against the junction axis
    envelope = widths * torch.sinc(b * widths)
    phase = 2 * np.pi * b * centres
    real = (envelope * torch.cos(phase)).sum(dim=1)
    imag = (envelope * torch.sin(phase)).sum(dim=1)
    return torch.hypot(real, imag) / widths.sum()


# Loop-invariant target: build it once rather than every epoch.
experimental_data_tensor = torch.tensor(experimental_data, dtype=torch.float32)
# NOTE(review): the element-wise loss assumes the experimental I_c samples lie on
# this same field grid, in the same order — confirm against the CSV.
if experimental_data_tensor.shape != mag_field.shape:
    raise ValueError(
        f"experimental data has {experimental_data_tensor.shape[0]} points, "
        f"expected {mag_field.shape[0]} (one per field value)"
    )

for epoch in range(num_epochs):
    input_data = torch.randn(1, input_size)

    # The noise must pass through the generator so the loss depends on its weights.
    generated_array_1d = generator(input_data).flatten()
    junction_edges = enforce_constraints(generated_array_1d)
    generated_array_numpy = junction_edges.detach().numpy()  # snapshot for inspection after training

    # Stay in torch end to end: converting to NumPy and re-wrapping as new leaf
    # tensors would detach the graph, leaving the generator without gradients.
    predicted_graph_tensor = _fraunhofer_pattern(junction_edges, mag_field)

    optimizer.zero_grad()

    loss = cost_function(gen_data_tensor=experimental_data_tensor, model_data_tensor=predicted_graph_tensor)

    loss.backward()
    optimizer.step()

    if epoch % 100 == 0:
        print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}")