In [4]:
import os
import time

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import TensorDataset, Dataset, DataLoader
from torch.optim.lr_scheduler import _LRScheduler
import torch.optim as optim

from sklearn.metrics import roc_curve, auc
from sklearn.preprocessing import MinMaxScaler

from typing import List, Tuple

In [2]:
if torch.backends.mps.is_available():
    mps_device = torch.device("mps")
    x = torch.ones(1, device=mps_device)
    print (x)
else:
    print ("MPS device not found.")

tensor([1.], device='mps:0')


In [14]:
def get_data_densematrix(dataset_size: int, m: int, n: int, *, 
                         seed: int, measurement_type: str = "mask",
                         include_X: bool = False, mask_prob: float = 0.33, 
                         rank_threshold: float = 0.1) -> Tuple[torch.Tensor, torch.Tensor, 
                                                               torch.Tensor, torch.Tensor]:
    torch.manual_seed(seed)
    
    # Generate matrix ranks
    rank = torch.randint(1, min(m, n) // 2, (dataset_size,))
    n_seen = round((1 - mask_prob) * m * n)
    rank = torch.minimum(rank, torch.tensor(n_seen // 2))
    
    # Generate low-rank matrices by multiplying m x r and r x n matrices
    A = torch.randn(dataset_size, m, rank.max().item())
    B = torch.randn(dataset_size, rank.max().item(), n)
    
    # Create a mask to zero out extra columns in A and rows in B
    mask_A = torch.arange(rank.max().item()).expand(dataset_size, -1) < rank.unsqueeze(1)
    mask_B = torch.arange(rank.max().item()).expand(dataset_size, -1) < rank.unsqueeze(1)
    
    A = A * mask_A.unsqueeze(1)
    B = B * mask_B.unsqueeze(2)
    
    X = torch.matmul(A, B)
    
    # Generate labels with shape (dataset_size, 1)
    Y = (rank < min(m, n) * rank_threshold).int().unsqueeze(1)
    
    if measurement_type == "mask":
        # Generate measurements of each X by masking some of the entries; dataset_size x m x n
        masks = torch.bernoulli(torch.full((dataset_size, m, n), 1 - mask_prob))
        Xhat = X * masks
    
    elif measurement_type == "trace":
        # Generate measurements of each X by taking Trace(X * O_i) for (2*maxrank) random matrices O_i
        O = torch.randn(dataset_size, 250, 50, m)
        # Function to compute the trace of each product 
        def trace_product(X, O):
            product = torch.matmul(O, X.unsqueeze(-1)).squeeze(-1)
            trace = torch.diagonal(product, dim1=-2, dim2=-1).sum(-1)
            return trace
        vectorized_trace_product = torch.vmap(
            torch.vmap(trace_product, in_dims=(None, 0)), 
            in_dims=(0, 0))
        # Compute the traces and products
        Xhat = vectorized_trace_product(X, O)
    
    else:
        raise ValueError("Measurement type not recognized. Choose from 'mask' or 'trace'.")
    
    if include_X:
        return X, Xhat, Y, rank
    else:
        return Xhat, Y, rank

In [15]:
# Test matrix data generation and visualize the data samples
dataset_size = 10
X, Xhat, Y, rank = get_data_densematrix(dataset_size, 25, 25, seed = 12,
                                        measurement_type="mask", include_X=True)
traces, Y, rank = get_data_densematrix(dataset_size, 25, 25, seed = 12,
                                       measurement_type="trace", include_X=False)
X.shape, Xhat.shape, traces.shape, Y.shape, rank.shape

(torch.Size([10, 25, 25]),
 torch.Size([10, 25, 25]),
 torch.Size([10, 250]),
 torch.Size([10, 1]),
 torch.Size([10]))

In [None]:
import matplotlib.pyplot as plt
fig, axs = plt.subplots(3, 4, figsize=(10, 10))
for i in range(4):
    axs[0, i].imshow(X[i], cmap="gray")
    axs[1, i].imshow(Xhat[i], cmap="gray")
    # Take traces vector and stack copies of the same trace to visualize
    axs[2, i].imshow(traces[i] * np.ones((traces.shape[-1], traces.shape[-1])), cmap="gray")
    axs[0, i].set_title(f"Ground truth X[{i}], r={np.linalg.matrix_rank(X[i])}")
    axs[1, i].set_title(f"Mask Xhat1[{i}]")
    axs[2, i].set_title(f"Traces Xhat2[{i}]'")
plt.suptitle("DGP - Ground truth and measurement samples")
plt.show()

In [None]:
class CustomDataset(Dataset):
    def __init__(self, Xhat, Y, rank, X=None):
        self.Xhat = Xhat
        self.Y = Y
        self.rank = rank
        self.X = X
        self.include_X = X is not None
        self.dataset_size = Xhat.shape[0]

    def __len__(self):
        return self.dataset_size

    def __getitem__(self, idx):
        if self.include_X:
            return self.X[idx], self.Xhat[idx], self.Y[idx], self.rank[idx]
        else:
            return self.Xhat[idx], self.Y[idx], self.rank[idx]

# Parameters for data generation
dataset_size = 1000
m = 250
n = 250
seed = 42
measurement_type = "mask"
include_X = True
mask_prob = 0.33
rank_threshold = 0.1

# Generate data
data = get_data_densematrix(dataset_size, m, n, 
                            seed=seed, 
                            measurement_type=measurement_type, 
                            include_X=include_X, 
                            mask_prob=mask_prob, 
                            rank_threshold=rank_threshold)

# Unpack generated data
if include_X:
    X, Xhat, Y, rank = data
else:
    Xhat, Y, rank = data
    X = None

# Create dataset
dataset = CustomDataset(Xhat, Y, rank, X)

# Create DataLoader
batch_size = 32
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=4)

# Iterate over DataLoader
# for batch in dataloader:
#    print(batch)

In [31]:
# This cell is the mGRU from Kamesh Krishnamurthy
# h(t+1) = \sigma(Wh * h + Wx*x + b1) \odot \phi(Jh*h + Jx*x + b2)
class mGRU(nn.Module):
    def __init__(
        self,
        input_size,
        hidden_size,
        input_transform=None,
        output_transform=None,
        binaryoutput=False,
        device="cpu",
    ):
        super(mGRU, self).__init__()
        self.sigmoid = nn.Sigmoid()
        self.tanh = nn.Tanh()
        self.hidden_size = hidden_size
        self.input_size = input_size
        self.Wz = nn.Linear(hidden_size + input_size, hidden_size)
        self.Wh = nn.Linear(input_size + hidden_size, hidden_size)
        torch.nn.init.xavier_normal_(self.Wz.weight)
        torch.nn.init.xavier_normal_(self.Wh.weight, gain=5.0 / 3.0)
        self.params = [self.Wh.weight, self.Wh.bias, 
                       self.Wz.weight, self.Wz.bias]
        self.input_transform = input_transform
        self.output_transform = output_transform

        # learnable initial hidden state
        self.initial_hidden_state = nn.Parameter(torch.zeros(1, hidden_size) * 0.05)
        self.device = device
        self.binaryoutput = binaryoutput
        
        if self.binaryoutput:
            self.fc = nn.Linear(hidden_size, 1)
            torch.nn.init.xavier_normal_(self.fc.weight)

    # hidden : batch_size x hidden_dim
    # x : batch_size x seq_length x input_dim
    def forward(self, x, hidden=None):
        # hidden should be batch_size x hidden_dim
        # x should be batch_size x seq_length x input_dim
        if self.input_transform is not None:
            x = self.input_transform(x)

        _batch_size = x.shape[0]
        _seq_len = x.shape[1]
        _input_dim = x.shape[2]
        _hidden_dim = self.hidden_size

        if hidden is None:
            # hidden = torch.zeros(1, _hidden_dim).repeat(_batch_size,1)
            hidden = self.get_initial_state(_batch_size)
        _output_seq = torch.zeros(_batch_size, _seq_len, _hidden_dim).to(self.device)

        for t in range(x.shape[1]):
            _ip = x[:, t, :]  # shape is batch_size x input_dim
            # ip_combined = torch.cat((hidden[0,:,:],_ip),1)  
            # # shape is batch_size x (input_dim + hidden_dim)
            ip_combined = torch.cat((hidden, _ip), -1)  
            # shape is batch_size x (input_dim + hidden_dim)
            self.z = self.sigmoid(self.Wz(ip_combined))
            hidden = torch.mul(self.z, self.tanh(self.Wh(ip_combined)))
            _output_seq[:, t, :] = hidden

        if self.output_transform is not None:
            _output_seq = self.output_transform(_output_seq)

        if self.binaryoutput:
            # Apply the final linear layer and sigmoid activation for binary prediction
            _output_seq = self.sigmoid(self.fc(_output_seq[:, -1, :]))
            return _output_seq.squeeze(), hidden
            
        return _output_seq, hidden

    def get_initial_state(self, batch_size):
        return self.initial_hidden_state.repeat(batch_size, 1)

In [19]:
class ExponentialDecayLR(_LRScheduler):
    def __init__(self, optimizer, initial_lr, decay_rate, decay_steps, last_epoch=-1):
        self.initial_lr = initial_lr
        self.decay_rate = decay_rate
        self.decay_steps = decay_steps
        super(ExponentialDecayLR, self).__init__(optimizer, last_epoch)

    def get_lr(self):
        return [self.initial_lr * self.decay_rate ** (self.last_epoch / self.decay_steps) 
                for _ in self.base_lrs]

# Example usage
#optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
#scheduler = ExponentialDecayLR(optimizer, initial_lr=0.1, decay_rate=0.96, decay_steps=100)
#for epoch in range(1000):
#    train(...)
#    validate(...)
#    scheduler.step()

In [20]:
def plot_loss_curves(train_loss, val_loss):
    fig, ax = plt.subplots()
    ax.plot(train_loss, label="Training Loss")
    ax.plot(val_loss, label="Validation Loss")
    ax.set_xlabel("Steps")
    ax.set_ylabel("Loss")
    ax.set_title("Training and Validation Loss")
    ax.legend()
    plt.show()
    
    return fig

def plot_eigenvalues(wR, multilayer: bool = False):
    fig, ax = plt.subplots()
    x = np.linspace(-1, 1, 1000)
    ax.plot(x, np.sqrt(1 - x**2), 'k')
    ax.plot(x, -np.sqrt(1 - x**2), 'k')
    if multilayer:
        for i, weights in enumerate(wR):
            eigvals, _ = np.linalg.eig(weights)
            ax.plot(np.real(eigvals), np.imag(eigvals), '.', label=f"Layer {i}")
        ax.legend()
    else:
        eigvals, _ = np.linalg.eig(wR)
        ax.plot(np.real(eigvals), np.imag(eigvals), '.')
    ax.set_title("Eigenvalues of Recurrent Weights")
    ax.set_xlabel("Real")
    ax.set_ylabel("Imaginary")
    # equal axis
    ax.set_aspect('equal', adjustable='box')
    plt.show()    
    return fig

In [22]:
model = mGRU(input_size=250, hidden_size=64, device=mps_device)
model.to(mps_device)

mGRU(
  (sigmoid): Sigmoid()
  (tanh): Tanh()
  (Wz): Linear(in_features=314, out_features=64, bias=True)
  (Wh): Linear(in_features=314, out_features=64, bias=True)
)

In [40]:
def main(
    dataset_size=5000,
    val_size=500,
    batch_size=32,
    initial_lr=3e-3,
    decay_rate=0.999,
    steps=500,
    hidden_size=32,
    depth=1,
    seed=5678,
    device="cpu",
):
    # Set random seed for reproducibility
    torch.manual_seed(seed)
    
    # Generate data
    xs, ys, _ = get_data_densematrix(dataset_size + val_size, 250, 250, 
                                     seed=seed, 
                                     mask_prob=1/2, 
                                     rank_threshold=0.2, 
                                     measurement_type="mask", 
                                     include_X=False)
    xs_train, ys_train = xs[:dataset_size], ys[:dataset_size]
    xs_val, ys_val = xs[dataset_size:], ys[dataset_size:]
    print("Generated training and validation data.")
    
    # Create DataLoader
    train_dataset = TensorDataset(xs_train, ys_train)
    val_dataset = TensorDataset(xs_val, ys_val)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, 
                              shuffle=True, num_workers=4)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, 
                            shuffle=False, num_workers=4)
    print("Created DataLoaders.")
    
    # Create the model
    inputdim = xs_train.shape[-1]
    model = mGRU(input_size=inputdim, hidden_size=hidden_size, 
                 binaryoutput=True,
                 device=device).to(device)
    print("mGRU model initialized and sent to device.")
    
    # Save the initial recurrent weights
    initWz = model.Wz.weight.clone().detach().cpu().numpy()
    print("Initial recurrent weights saved.")
    
    # Define loss function (binary cross entropy) and optimizer
    criterion = nn.BCELoss()
    optimizer = optim.Adam(model.parameters(), lr=initial_lr)
    print("Adam optimizer initialized.")
    
    # Create learning rate schedule
    scheduler = ExponentialDecayLR(optimizer, 
                                   initial_lr=initial_lr, 
                                   decay_rate=decay_rate, 
                                   decay_steps=steps)
    print("Learning rate schedule created.")
    
    # Training loop
    losses = []
    val_losses = []
    print("Training starting.")
    for step in range(steps):
        model.train()
        for x_batch, y_batch in train_loader:
            x_batch, y_batch = x_batch.to(device), y_batch.to(device)
            optimizer.zero_grad()
            outputs, hidden = model(x_batch)
            loss = criterion(outputs, y_batch.float().squeeze())
            loss.backward()
            optimizer.step()
            scheduler.step()
        
        # Validation loss
        model.eval()
        val_loss = 0
        with torch.no_grad():
            for x_batch, y_batch in val_loader:
                x_batch, y_batch = x_batch.to(device), y_batch.to(device)
                outputs, hidden = model(x_batch)
                val_loss += criterion(outputs, y_batch.float().squeeze()).item()
        val_loss /= len(val_loader)
        
        # Append the training and validation loss
        losses.append(loss.item())
        val_losses.append(val_loss)
        
        # Print the loss every 50 steps, and at the final step
        if step % 50 == 0 or step == steps - 1:
            print(f"step={step}, loss={np.round(loss.item(),6)}, val_loss={np.round(val_loss,6)}")
        
        # Middle of training, save the recurrent weights
        if step == steps // 2:
            midWz = model.Wz.weight.clone().detach().cpu().numpy()
    
    endWz = model.Wz.weight.clone().detach().cpu().numpy()
    
    # Final accuracy
    model.eval()
    with torch.no_grad():
        pred_ys = model(xs_train.to(device))
        num_correct = ((pred_ys > 0.5) == ys_train.to(device)).sum().item()
        final_accuracy = num_correct / dataset_size
    print(f"Final accuracy = {np.round(final_accuracy,6)}")
    
    # Plot training and validation loss curves
    lossfig = plot_loss_curves(losses, val_losses)
    
    # Save the model
    return model, lossfig, initWz, midWz, endWz

In [42]:
model, lossfig, initWz, midWz, endWz = main(dataset_size=1000,
    val_size=100,
    batch_size=16,
    steps=10,
    device=mps_device)

Generated training and validation data.
Created DataLoaders.
mGRU model initialized and sent to device.
Initial recurrent weights saved.
Adam optimizer initialized.
Learning rate schedule created.
Training starting.
step=0, loss=0.738697, val_loss=0.709023
step=9, loss=0.015143, val_loss=0.938615


TypeError: '>' not supported between instances of 'tuple' and 'float'

In [None]:
initWz.shape, midWz.shape, endWz.shape

In [None]:
# Create 3 x 3 panel of plots showing the eigenvalues of the reset, input, and new weights, 
# at initialization, middle of training, and end of training.
fig, axs = plt.subplots(3, 3, figsize=(13, 13))
rowlabels = ["Initialization", "Middle of Training", "End of Training"]
x = np.linspace(-1, 1, 1000)
for i, (reset, inp, new) in enumerate(zip(resetW, inputW, newW)):
    for j, weights in enumerate([reset, inp, new]):
        axs[i, j].plot(x, np.sqrt(1 - x**2), 'k')
        axs[i, j].plot(x, -np.sqrt(1 - x**2), 'k')
        eigvals, _ = np.linalg.eig(weights)
        axs[i, j].plot(np.real(eigvals), np.imag(eigvals), '.')
        axs[i, j].set_title(f"{rowlabels[i]}, {['Reset gate', 'Input gate', 'New hidden state'][j]}")
        axs[i, j].set_xlabel("Real")
        axs[i, j].set_ylabel("Imaginary")
        #axs[i, j].set_aspect('equal', adjustable='box')
#plt.suptitle("Eigenvalues of Recurrent Weights")
fig.tight_layout()
plt.show()

In [None]:
# Test the trained model on a new test set and evaluate the accuracy
test_xs, test_ys, _ = get_data_densematrix(500, 250, 250,
                                             seed=1234, 
                                             mask_prob=1/2, 
                                             rank_threshold=0.2, 
                                             measurement_type="mask", 
                                             include_X=False)
model.eval()
with torch.no_grad():
    pred_test_ys = model(test_xs.to(mps_device))
    num_correct = ((pred_test_ys > 0.5) == test_ys.to(mps_device)).sum().item()
    test_accuracy = num_correct / 500
print(f"Test accuracy = {np.round(test_accuracy,6)}")

In [None]:
# Plot ROC curve
fpr, tpr, _ = roc_curve(test_ys, pred_test_ys)
roc_auc = auc(fpr, tpr)

plt.figure()
plt.plot(fpr, tpr, color="darkorange", lw=2, label=f"ROC curve (area = {roc_auc:.2f})")
plt.plot([0, 1], [0, 1], color="navy", lw=2, linestyle="--")
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.0])
plt.xlabel("False Positive Rate, 1 - Specificity")
plt.ylabel("True Positive Rate, Sensitivity")
plt.title("Receiver Operating Characteristic")
plt.legend(loc="lower right")
plt.show()