In [107]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from src import models,utils
import numpy as np
import matplotlib.pyplot as plt
import torch.nn.functional as F

class ConvRNN(nn.Module):
    def __init__(self, rnn_hidden_size, num_classes):
        super(ConvRNN, self).__init__()

        self.conv = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2)
        )

        rnn_input_size = self.cnn_output_size((1, 28, 28))
        #self.rnn = nn.RNN(rnn_input_size, rnn_hidden_size, batch_first=True)
        
        self.rnn = models.CTRNN(rnn_input_size, rnn_hidden_size, device, dt=5, constraint="spec")
        self.fc = nn.Linear(rnn_hidden_size, num_classes)
        
    def cnn_output_size(self, input_size):
        x = torch.randn(1, *input_size)
        x = self.conv(x)
        return x.view(-1).size(0)

    def forward(self, x):
        x = self.conv(x)
        x = x.view(x.size(0), -1)  # Flatten the tensor
        x = x.unsqueeze(1).repeat(1, 50, 1)  # Repeat the tensor for 50 timesteps
        outputs, _ = self.rnn(x)
        x = outputs[:,-1,:]
        x = self.fc(x)
        return x


class LinearRNN(nn.Module):
    def __init__(self, input_size, rnn_hidden_size, num_classes):
        super(LinearRNN, self).__init__()

        #self.linear = nn.Linear(input_size, rnn_hidden_size)
        
        #self.rnn = nn.RNN(rnn_hidden_size, rnn_hidden_size, batch_first=True)
        
        self.rnn = models.CTRNN(input_size, rnn_hidden_size, device, dt=15, constraint= "sym")
        
        
        self.fc = nn.Linear(rnn_hidden_size, num_classes)

    def forward(self, x):
        x = x.view(x.size(0), -1)
        #x = self.linear(x)
        x = x.unsqueeze(1).repeat(1, 25, 1)
        outputs, _ = self.rnn(x)

        x = outputs[:,-1,:]
        x = self.fc(x)
        return x

input_size = 28*28
rnn_hidden_size = 64
num_classes = 10
model = LinearRNN(input_size,rnn_hidden_size, num_classes)


In [86]:
from scipy import linalg

def confirm_condition(A,M):
    
    eigs,_ = np.linalg.eig(A @ M + M @ A.T)
    
    e_real = np.real(eigs)
    
    if any(e_real > 0):               
        return 0
    else:
        return 1

def generate_random_sparse_contracting_weight_and_metric(n):    
    I = np.eye(n)

    g0 = 1/np.sqrt(n)           #initial gain for random matrix
    s0 = 0.9                    #initial sparsity
    gain = 0.99


    mask = np.random.choice([False,True], n**2, p = [s0,1-s0]).reshape(n,n)
    W0 = mask*np.random.normal(0, g0, size = (n,n))
    W0 -= np.diag(W0)

    A = np.abs(W0 - np.diag(W0)) - I

    M = linalg.solve_continuous_lyapunov(A.T, -I)   

    condition_met = confirm_condition(A,M)    


    while not condition_met:

        W0 *= gain

        A = np.abs(W0) - I

        M = linalg.solve_continuous_lyapunov(A.T, -I)   

        condition_met = confirm_condition(A,M) 
        
    return W0, M 


n = 200
W,M = generate_random_sparse_contracting_weight_and_metric(n)
W

array([[ 0.        ,  0.        , -0.        , ...,  0.        ,
        -0.        ,  0.0069099 ],
       [ 0.        ,  0.        ,  0.        , ..., -0.        ,
        -0.        ,  0.0069099 ],
       [ 0.00607269,  0.        ,  0.        , ..., -0.        ,
        -0.01851189,  0.0069099 ],
       ...,
       [ 0.        ,  0.        ,  0.        , ...,  0.        ,
         0.        ,  0.0069099 ],
       [ 0.00324453,  0.        ,  0.        , ..., -0.        ,
         0.        , -0.00984114],
       [ 0.        ,  0.        , -0.        , ...,  0.        ,
        -0.        ,  0.        ]])

In [104]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

ns = [64, 64, 64]
p = len(ns)
A_tril = torch.zeros((len(ns), len(ns)))
A_tril[-1, :] = 1
B_mask = utils.create_mask_given_A(A_tril, ns)


#generate intra_areal weights 

inter_areal_constraint = 'conformal'
intra_areal_constraint = 'sym'

W_bar, Ws = utils.create_random_block_stable_symmetric_weights(ns)


if inter_areal_constraint == "None":
    # No stability constraint on interareal weights
    B_mask = 0.5 * (B_mask + B_mask.T)
    M_bar = torch.eye(sum(ns), device=device)

if inter_areal_constraint == "conformal":
    # Stability constraint on interareal weights, conformal to the stability constraint on the subnetworks

    if intra_areal_constraint == "spectral":
        M_bar = torch.eye(sum(ns), device=device)

    if intra_areal_constraint == "sym" or intra_areal_constraint == "None":
        with torch.no_grad():
            Ms = utils.compute_metric_from_weights(
                Ws, ctype= intra_areal_constraint, device='cpu'
            )
            M_bar = torch.block_diag(*Ms)
            M_bar.to(device)
            
            W_bar.to(device)
            

In [119]:
def build_GWNET_random(input_size, 
                       ns, 
                       output_size,
                       device, 
                       gw_hidden_size=32, 
                       intra_areal_constraint = 'sym',
                       inter_areal_constraint="conformal"):
    
    p = len(ns)
    A_tril = torch.zeros((len(ns), len(ns)))
    A_tril[-1, :] = 1
    B_mask = utils.create_mask_given_A(A_tril, ns)


    #generate intra_areal weights 

    inter_areal_constraint = 'conformal'
    intra_areal_constraint = 'sym'

    W_bar, Ws = utils.create_random_block_stable_symmetric_weights(ns)


    if inter_areal_constraint == "None":
        # No stability constraint on interareal weights
        B_mask = 0.5 * (B_mask + B_mask.T)
        M_bar = torch.eye(sum(ns), device=device)

    if inter_areal_constraint == "conformal":
        # Stability constraint on interareal weights, conformal to the stability constraint on the subnetworks

        if intra_areal_constraint == "spectral":
            M_bar = torch.eye(sum(ns), device=device)

        if intra_areal_constraint == "sym" or intra_areal_constraint == "None":
            with torch.no_grad():
                Ms = utils.compute_metric_from_weights(
                    Ws, ctype= intra_areal_constraint, device='cpu'
                )
                M_bar = torch.block_diag(*Ms)
                M_bar.to(device)

                W_bar.to(device)

    B_mask = F.dropout(B_mask, 0.5)

    # Create global workspace weights and biases here
    tmp = nn.Linear(gw_hidden_size, gw_hidden_size, bias=True, device=device)
    b_bar = tmp.bias
    
    

    stacked_wb = {"rnn_h2h_weight": W_bar,
                  "rnn_h2h_bias": b_bar}

    B_mask = B_mask.to(device)

    net = models.GW_RNNNet(
        stacked_wb,
        input_size,
        ns,
        output_size,
        M_hat,
        B_mask,
        device,
        inter_areal_constraint,
        dt=30,
    )
    net.to(device)

    return net

In [120]:
net = build_GWNET_random(input_size = 10, 
                       ns = [64, 64, 64], 
                        output_size = 10,
                       device = device, 
                       gw_hidden_size=32, 
                       intra_areal_constraint = 'sym',
                       inter_areal_constraint="conformal")

TypeError: cat() received an invalid combination of arguments - got (Parameter, dim=int), but expected one of:
 * (tuple of Tensors tensors, int dim, *, Tensor out)
 * (tuple of Tensors tensors, name dim, *, Tensor out)


In [20]:
batch_size = 64

transform = transforms.Compose([transforms.Resize((28, 28)),
                                transforms.ToTensor(),
                                transforms.Normalize((0.5,), (0.5,))])

train_data = datasets.FashionMNIST(root='./data', train=True, download=False, transform=transform)
test_data = datasets.FashionMNIST(root='./data', train=False, download=False, transform=transform)

train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=False)


In [65]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters())

num_epochs = 10

for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0

    for i, (images, labels) in enumerate(train_loader):
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()

        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        
    

    model.eval()
    correct = 0
    total = 0

    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)

            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    accuracy = 100 * correct / total
    print(f"Accuracy of the model on the test dataset: {accuracy:.2f}%")


    #print(f"Epoch {epoch+1}/{num_epochs}, Loss: {running_loss / len(train_loader):.6f}")

print("Finished Training")


Accuracy of the model on the test dataset: 74.54%
Accuracy of the model on the test dataset: 77.35%
Accuracy of the model on the test dataset: 80.26%
Accuracy of the model on the test dataset: 82.07%
Accuracy of the model on the test dataset: 81.80%


KeyboardInterrupt: 