# Janus

We provide a stand-alone implementation of Janus in this jupyter notebook that evaluates our 55k parameter model on the 3 protein fitness (GB1, Gifford, GFP) datasets.


**NOTE:** For hardware optimization, we use the FlashFFTConv library in this implementation of Janus in this demo so we can only confirm that this code will run on A100, H100 GPUs or RTX 3090 and 4090. for more details on GPU requirements, please visit the (FlashFFTConv github: https://github.com/HazyResearch/flash-fft-conv)

# Installation

In [None]:
!python -m pip install --upgrade pip

Collecting pip
  Downloading pip-24.0-py3-none-any.whl (2.1 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.1/2.1 MB[0m [31m36.8 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: pip
  Attempting uninstall: pip
    Found existing installation: pip 23.1.2
    Uninstalling pip-23.1.2:
      Successfully uninstalled pip-23.1.2
Successfully installed pip-24.0


In [None]:
!pip install einops pandas pytest tqdm scipy pyarrow torchaudio

Collecting einops
  Downloading einops-0.7.0-py3-none-any.whl.metadata (13 kB)
Downloading einops-0.7.0-py3-none-any.whl (44 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m44.6/44.6 kB[0m [31m1.6 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: einops
Successfully installed einops-0.7.0
[0m

In [None]:
!pip install torch torchaudio

[0m

In [None]:
# !pip install torch==1.10.0+cpu torchvision==0.11.0+cpu torchaudio==0.10.0+cpu -f https://download.pytorch.org/whl/torch_stable.html

Looking in links: https://download.pytorch.org/whl/torch_stable.html
[31mERROR: Operation cancelled by user[0m[31m
[0m

# Imports

In [None]:
import numpy as np
import pandas as pd
import json
import math

import os

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio

from torch.utils.data import Dataset, DataLoader
from einops import rearrange, repeat
from tqdm.auto import tqdm
from scipy.stats import spearmanr

dropout_fn = nn.Dropout1d
torch.backends.cudnn.benchmark = True
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
!unzip negative_input_processed.zip
!unzip positive_input_processed.zip

Archive:  negative_input_processed.zip
   creating: negative_input_processed/
  inflating: negative_input_processed/neg_10_16.mp3  
  inflating: negative_input_processed/neg_01_18.mp3  
  inflating: negative_input_processed/neg_05_03.mp3  
  inflating: negative_input_processed/neg_07_12.mp3  
  inflating: negative_input_processed/neg_06_18.mp3  
  inflating: negative_input_processed/neg_06_05.mp3  
  inflating: negative_input_processed/neg_05_11.mp3  
  inflating: negative_input_processed/neg_03_19.mp3  
  inflating: negative_input_processed/neg_01_10.mp3  
  inflating: negative_input_processed/neg_06_19.mp3  
  inflating: negative_input_processed/neg_07_09.mp3  
  inflating: negative_input_processed/neg_05_06.mp3  
  inflating: negative_input_processed/neg_08_11.mp3  
  inflating: negative_input_processed/neg_10_09.mp3  
  inflating: negative_input_processed/neg_03_12.mp3  
  inflating: negative_input_processed/neg_03_13.mp3  
  inflating: negative_input_processed/neg_07_17.mp3  
  in

# Dataloader

Here we provide the dataloader for the audio call prediction tasks.

In [84]:
def get_audio_files_and_labels(audio_folders):
    files_and_labels = []
    for folder in audio_folders:
        label = np.array([0, 1]) if "positive" in folder else np.array([1, 0])
        for file in os.listdir(folder):
            files_and_labels.append((os.path.join(folder, file), label))
    return files_and_labels

audio_folders = ['negative_input_processed', 'positive_input_processed']
all_data = get_audio_files_and_labels(audio_folders)

In [85]:
from sklearn.model_selection import train_test_split

train_data, test_data = train_test_split(all_data, test_size=0.2, random_state=42)


In [86]:
len(train_data)

253

In [87]:
from torch.utils.data import Dataset, DataLoader


class AudioDataset(Dataset):
    def __init__(self, annotations, transformation, target_sample_rate, num_samples, device='cpu'):
        self.annotations = annotations
        self.transformation = transformation.to(device)
        self.target_sample_rate = target_sample_rate
        self.num_samples = num_samples
        self.device = device

    def __len__(self):
        return len(self.annotations)

    def __getitem__(self, index):
        audio_sample_path, label = self.annotations[index]
        # print(audio_sample_path)
        signal, sr = torchaudio.load(audio_sample_path)
        signal = signal.to(self.device)
        if sr != self.target_sample_rate:
            # raise Exception(f"Sample rate {sr} is not {self.target_sample_rate}.")
            resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=self.target_sample_rate).to(self.device)
            signal = resampler(signal)
        # print(signal.shape)
        if signal.shape[0] > 1:
            signal = torch.mean(signal, dim=0, keepdim=True)
        if signal.shape[1] > self.num_samples:
            signal = signal[:, :self.num_samples]
        elif signal.shape[1] < self.num_samples:
            padding = self.num_samples - signal.shape[1]
            signal = torch.cat([signal, torch.zeros(1, padding).to(self.device)], dim=1)
        signal = self.transformation(signal)
        return torch.squeeze(signal.permute(0, 2, 1)), label

# Example usage
transformation = torchaudio.transforms.Spectrogram()
target_sample_rate = 16000  # Define according to your needs
num_samples = target_sample_rate * 15  # 15 seconds of audio

train_dataset = AudioDataset(train_data, transformation, target_sample_rate, num_samples)
test_dataset = AudioDataset(test_data, transformation, target_sample_rate, num_samples)

train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=16, shuffle=False)


In [88]:
samples, labels = next(iter(train_loader))
labels.shape

torch.Size([16, 2])

# RMSNorm

In [89]:
class RMSNorm(nn.Module):
    def __init__(self, d, p=-1., eps=1e-8, bias=False):
        """
            Root Mean Square Layer Normalization
        :param d: model size
        :param p: partial RMSNorm, valid value [0, 1], default -1.0 (disabled)
        :param eps:  epsilon value, default 1e-8
        :param bias: whether use bias term for RMSNorm, disabled by
            default because RMSNorm doesn't enforce re-centering invariance.
        """
        super(RMSNorm, self).__init__()

        self.eps = eps
        self.d = d
        self.p = p
        self.bias = bias

        self.scale = nn.Parameter(torch.ones(d))
        self.register_parameter("scale", self.scale)

        if self.bias:
            self.offset = nn.Parameter(torch.zeros(d))
            self.register_parameter("offset", self.offset)

    def forward(self, x):
        if self.p < 0. or self.p > 1.:
            norm_x = x.norm(2, dim=-1, keepdim=True)
            d_x = self.d
        else:
            partial_size = int(self.d * self.p)
            partial_x, _ = torch.split(x, [partial_size, self.d - partial_size], dim=-1)

            norm_x = partial_x.norm(2, dim=-1, keepdim=True)
            d_x = partial_size

        rms_x = norm_x * d_x ** (-1. / 2)
        x_normed = x / (rms_x + self.eps)

        if self.bias:
            return self.scale * x_normed + self.offset

        return self.scale * x_normed

# S4D

In [68]:
class DropoutNd(nn.Module):
    def __init__(self, p: float = 0.5, tie=True, transposed=True):
        """
        tie: tie dropout mask across sequence lengths (Dropout1d/2d/3d)
        """
        super().__init__()
        if p < 0 or p >= 1:
            raise ValueError("dropout probability has to be in [0, 1), " "but got {}".format(p))
        self.p = p
        self.tie = tie
        self.transposed = transposed
        self.binomial = torch.distributions.binomial.Binomial(probs=1-self.p)

    def forward(self, X):
        """X: (batch, dim, lengths...)."""
        if self.training:
            if not self.transposed: X = rearrange(X, 'b ... d -> b d ...')
            # binomial = torch.distributions.binomial.Binomial(probs=1-self.p) # This is incredibly slow because of CPU -> GPU copying
            mask_shape = X.shape[:2] + (1,)*(X.ndim-2) if self.tie else X.shape
            # mask = self.binomial.sample(mask_shape)
            mask = torch.rand(*mask_shape, device=X.device) < 1.-self.p
            X = X * mask * (1.0/(1-self.p))
            if not self.transposed: X = rearrange(X, 'b d ... -> b ... d')
            return X
        return X

class S4DKernel(nn.Module):
    """Generate convolution kernel from diagonal SSM parameters."""

    def __init__(self, d_model, N=64, dt_min=0.001, dt_max=0.1, lr=None):
        super().__init__()
        # Generate dt
        H = d_model
        log_dt = torch.rand(H) * (
            math.log(dt_max) - math.log(dt_min)
        ) + math.log(dt_min)

        C = torch.randn(H, N // 2, dtype=torch.cfloat)
        self.C = nn.Parameter(torch.view_as_real(C))
        self.register("log_dt", log_dt, lr)

        log_A_real = torch.log(0.5 * torch.ones(H, N//2))
        A_imag = math.pi * repeat(torch.arange(N//2), 'n -> h n', h=H)
        self.register("log_A_real", log_A_real, lr)
        self.register("A_imag", A_imag, lr)

    def forward(self, L):
        """
        returns: (..., c, L) where c is number of channels (default 1)
        """

        # Materialize parameters
        dt = torch.exp(self.log_dt) # (H)
        C = torch.view_as_complex(self.C) # (H N)
        A = -torch.exp(self.log_A_real) + 1j * self.A_imag # (H N)

        # Vandermonde multiplication
        dtA = A * dt.unsqueeze(-1)  # (H N)
        K = dtA.unsqueeze(-1) * torch.arange(L, device=A.device) # (H N L)
        C = C * (torch.exp(dtA)-1.) / A
        K = 2 * torch.einsum('hn, hnl -> hl', C, torch.exp(K)).real

        return K

    def register(self, name, tensor, lr=None):
        """Register a tensor with a configurable learning rate and 0 weight decay"""

        if lr == 0.0:
            self.register_buffer(name, tensor)
        else:
            self.register_parameter(name, nn.Parameter(tensor))

            optim = {"weight_decay": 0.0}
            if lr is not None: optim["lr"] = lr
            setattr(getattr(self, name), "_optim", optim)


class S4D(nn.Module):
    def __init__(self, d_model, d_state=64, dropout=0.0, transposed=True, **kernel_args):
        super().__init__()

        self.h = d_model
        self.n = d_state
        self.d_output = self.h
        self.transposed = transposed

        self.D = nn.Parameter(torch.randn(self.h))

        # SSM Kernel
        self.kernel = S4DKernel(self.h, N=self.n, **kernel_args)

        # Pointwise
        self.activation = nn.GELU()
        # dropout_fn = nn.Dropout2d # NOTE: bugged in PyTorch 1.11
        dropout_fn = DropoutNd
        self.dropout = dropout_fn(dropout) if dropout > 0.0 else nn.Identity()

        # position-wise output transform to mix features
        self.output_linear = nn.Sequential(
            nn.Conv1d(self.h, 2*self.h, kernel_size=1),
            nn.GLU(dim=-2),
        )

    def forward(self, u, **kwargs): # absorbs return_output and transformer src mask
        """ Input and output shape (B, H, L) """
        if not self.transposed: u = u.transpose(-1, -2)
        L = u.size(-1)

        # Compute SSM Kernel
        k = self.kernel(L=L) # (H L)

        # Convolution
        k_f = torch.fft.rfft(k, n=2*L) # (H L)
        u_f = torch.fft.rfft(u, n=2*L) # (B H L)
        y = torch.fft.irfft(u_f*k_f, n=2*L)[..., :L] # (B H L)

        # Compute D term in state space equation - essentially a skip connection
        y = y + u * self.D.unsqueeze(-1)

        y = self.dropout(self.activation(y))
        y = self.output_linear(y)
        if not self.transposed: y = y.transpose(-1, -2)
        return y

# Janus

In [72]:
class Janus(nn.Module):
    def __init__(
        self,
        d_input,
        d_output=10,
        d_model=256,
        n_layers=4,
        dropout=0.2,
        prenorm=True,
    ):
        super().__init__()

        # Linear encoder
        self.encoder = nn.Linear(d_input, d_model)
        self.prenorm = prenorm

        # Stack S4 layers as residual blocks
        self.s4_layers = nn.ModuleList()
        self.norms = nn.ModuleList()
        self.dropouts = nn.ModuleList()
        for _ in range(n_layers):
            self.s4_layers.append(
                S4D(d_model, dropout=dropout, transposed=True, lr=min(0.001, 0.002))
            )
            self.norms.append(RMSNorm(d_model))
            self.dropouts.append(dropout_fn(dropout))

        # Linear decoder
        self.decoder = nn.Linear(d_model, d_output)

    def forward(self, x):
        """
        Input x is shape (B, L, d_input)
        """
        x = self.encoder(x)  # (B, L, d_input) -> (B, L, d_model)
        #print(x.shape)
        x = x.transpose(-1, -2)  # (B, L, d_model) -> (B, d_model, L)
        for layer, norm, dropout in zip(self.s4_layers, self.norms, self.dropouts):
            # Each iteration of this loop will map (B, d_model, L) -> (B, d_model, L)
            z = x
            if self.prenorm:
                # Prenorm
                z = norm(z.transpose(-1, -2)).transpose(-1, -2)

            # Apply S4 block: we ignore the state input and output
            z = layer(z)

            # Dropout on the output of the S4 block
            z = dropout(z)

            # Residual connection
            x = z + x

            if not self.prenorm:
                # Postnorm
                x = norm(x.transpose(-1, -2)).transpose(-1, -2)

        x = x.transpose(-1, -2)

        # Pooling: average pooling over the sequence length
        x = x.mean(dim=1)

        # Decode the outputs
        x = self.decoder(x)  # (B, d_model) -> (B, d_output)

        return x

# Model Instantiation

This model instantiation is provided as a sanity check, the training loop in the next section will instantiate a new model for each task.

In [117]:
model_config = dict(
    d_model=32,
    n_layers=1,
    dropout=0.2,
    d_input=201, #input dim
    d_output=2, #num classes
    prenorm=True,
)

d_model = model_config['d_model']
n_layers = model_config['n_layers']

model = Janus(**model_config).to(device)

In [118]:
# Print model architecture
print(model)

# Count total trainable parameters
num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Total trainable parameters: {num_params}")

Janus(
  (encoder): Linear(in_features=201, out_features=32, bias=True)
  (s4_layers): ModuleList(
    (0): S4D(
      (kernel): S4DKernel()
      (activation): GELU(approximate='none')
      (dropout): DropoutNd()
      (output_linear): Sequential(
        (0): Conv1d(32, 64, kernel_size=(1,), stride=(1,))
        (1): GLU(dim=-2)
      )
    )
  )
  (norms): ModuleList(
    (0): RMSNorm()
  )
  (dropouts): ModuleList(
    (0): Dropout1d(p=0.2, inplace=False)
  )
  (decoder): Linear(in_features=32, out_features=2, bias=True)
)
Total trainable parameters: 12834


In [119]:
sample, label = next(iter(train_loader))
outputs = model(sample.to(device))

In [120]:
outputs.log_softmax(dim=1).exp()

tensor([[7.4372e-01, 2.5628e-01],
        [6.9853e-01, 3.0147e-01],
        [8.1746e-01, 1.8254e-01],
        [7.0092e-01, 2.9908e-01],
        [7.3801e-01, 2.6199e-01],
        [9.9960e-01, 3.9687e-04],
        [6.0693e-01, 3.9307e-01],
        [8.8141e-01, 1.1859e-01],
        [9.9890e-01, 1.0966e-03],
        [6.9760e-01, 3.0240e-01],
        [5.4452e-01, 4.5548e-01],
        [6.1720e-01, 3.8280e-01],
        [5.7418e-01, 4.2582e-01],
        [7.8751e-01, 2.1249e-01],
        [7.4590e-01, 2.5410e-01],
        [6.8831e-01, 3.1169e-01]], device='cuda:0', grad_fn=<ExpBackward0>)

In [121]:
label.shape

torch.Size([16, 2])

# Training and Evaluation Loops

In [122]:
# Here is the updated train loop for binary classfication:

import torch
import torch.nn as nn
import numpy as np
from tqdm import tqdm
from scipy.stats import spearmanr
from sklearn.metrics import roc_auc_score, accuracy_score

# Assuming you have defined the model, dataloaders, and other necessary components above

# Set the number of epochs
num_epochs = 20
best_loss = 5

# Initialize the model
model = Janus(**model_config).to(device)

# Loss function and optimizer
criterion = nn.BCELoss().to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=0.001, weight_decay=0.01)

pbar = tqdm(range(num_epochs), desc=f"Epoch 0: Train Acc: 0.0000, Val Acc: 0.0000, Loss: 0.0000")
for epoch in pbar:
    # Lists for storing metrics
    train_labels_list, train_outputs_list = [], []
    val_labels_list, val_outputs_list = [], []

    # Training phase
    model.train()
    for sequences, labels in train_loader:
        sequences = sequences.to(device)
        labels = labels.to(device)

        # Forward pass
        logits = model(sequences)
        # print(outputs)
        # argmax to get the output
        outputs = logits.log_softmax(dim=1).exp()
        # outputs = torch.argmax(outputs, dim=1, keepdim=True).float()
        # Calculate loss
        labels = labels.float()
        loss = criterion(outputs, labels)
        # Backward and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Store outputs and labels for calculating metrics
        outputs = torch.argmax(outputs, dim=1, keepdim=True)
        labels = torch.argmax(labels, dim=1, keepdim=True)
        train_outputs_list.append(outputs.detach().cpu().numpy())
        train_labels_list.append(labels.detach().cpu().numpy())

    # Calculate training accuracy and AUC
    train_outputs = np.vstack(train_outputs_list)
    train_labels = np.vstack(train_labels_list)
    train_acc = (train_outputs == train_labels).sum().item() / train_outputs.shape[0]

    # Validation phase
    model.eval()
    with torch.no_grad():
        for sequences, labels in test_loader:
            sequences = sequences.to(device)
            labels = labels.to(device)

            logits = model(sequences)
            outputs = logits.log_softmax(dim=1).exp()
            outputs = torch.argmax(outputs, dim=1, keepdim=True)
            labels = torch.argmax(labels, dim=1, keepdim=True)
            # Store outputs and labels for calculating metrics
            val_outputs_list.append(outputs.cpu().numpy())
            val_labels_list.append(labels.cpu().numpy())

    # Calculate validation accuracy and AUC
    val_outputs = np.vstack(val_outputs_list)
    val_labels = np.vstack(val_labels_list)
    val_acc = (val_outputs == val_labels).sum().item() / val_outputs.shape[0]

    # Update best AUC and save the model if improved
    if loss.item() < best_loss:
        best_loss = loss.item()
        torch.save(model.state_dict(), f'Janus_{d_model}_{n_layers}.pt')

    pbar.set_description(f"Epoch {epoch}: Train Acc: {train_acc:.4f}, Val Acc: {val_acc:.4f}, Loss: {loss.item():.4f}")

# model.load_state_dict(torch.load("INSERT FILENAME.pt"))

Epoch 19: Train Acc: 0.9881, Val Acc: 0.9844, Loss: 0.0036: 100%|██████████| 20/20 [01:54<00:00,  5.74s/it]
