In [1]:
import torch
import torch.nn as nn

In [2]:
# The four canonical nucleotides, in token-id order.
VOCAB = list("ACGT")
# Symbol -> integer token id. Nucleotides come first, then the gap symbol '-'
# and finally 'N'; insertion order matters because other cells slice the keys.
BASE2IDX = {symbol: index for index, symbol in enumerate("ACGT-N")}
# Number of distinct token ids (used as the embedding table size).
VOCAB_SIZE = len(BASE2IDX)

In [3]:
class RowAttention(nn.Module):
    """Multi-head self-attention over the taxa axis, applied separately per site.

    For an input of shape (B, N, L, D) every one of the B * L alignment sites is
    treated as an independent sequence of N tokens, so information is exchanged
    between the N rows at a single site and never between different sites.

    Args:
        dim: Embedding width D.
        heads: Number of attention heads.
    """

    def __init__(self, dim, heads):
        super().__init__()
        self.attn = nn.MultiheadAttention(dim, heads, batch_first=True)

    def forward(self, x):
        """Return a tensor with the same (B, N, L, D) layout as ``x``."""
        batch, n_taxa, n_sites, width = x.shape
        # Fold the site axis into the batch axis: (B, N, L, D) -> (B * L, N, D).
        per_site = x.transpose(1, 2).reshape(batch * n_sites, n_taxa, width)
        attended, _weights = self.attn(per_site, per_site, per_site)
        # Undo the folding and put the taxa axis back in front of the site axis.
        unfolded = attended.reshape(batch, n_sites, n_taxa, width)
        return unfolded.transpose(1, 2)

In [4]:
class ColumnAttention(nn.Module):
    """Multi-head self-attention over the site axis, applied separately per taxon.

    For an input of shape (B, N, L, D) each of the B * N sequences is treated as
    an independent sequence of L tokens, so information is exchanged between
    sites of one sequence and never between different sequences.

    Args:
        dim: Embedding width D.
        heads: Number of attention heads.
    """

    def __init__(self, dim, heads):
        super().__init__()
        self.attn = nn.MultiheadAttention(dim, heads, batch_first=True)

    def forward(self, x):
        """Return a tensor with the same (B, N, L, D) layout as ``x``."""
        batch, n_taxa, n_sites, width = x.shape
        # Fold the taxa axis into the batch axis: (B, N, L, D) -> (B * N, L, D).
        per_taxon = x.reshape(batch * n_taxa, n_sites, width)
        attended, _weights = self.attn(per_taxon, per_taxon, per_taxon)
        return attended.reshape(batch, n_taxa, n_sites, width)

In [5]:
class MSABlock(nn.Module):
    """Pre-norm residual block: RowAttention, then ColumnAttention, then an MLP.

    Each of the three sub-layers is preceded by its own LayerNorm and wrapped in
    a residual connection. The MLP expands to ``4 * dim``, applies GELU, projects
    back to ``dim`` and ends with dropout.

    Args:
        dim: Embedding width of the last tensor axis.
        heads: Number of heads for both attention sub-layers.
        dropout: Dropout probability applied at the end of the MLP.
    """

    def __init__(self, dim, heads, dropout=0.1):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        self.norm3 = nn.LayerNorm(dim)

        self.row_attn = RowAttention(dim, heads)
        self.col_attn = ColumnAttention(dim, heads)

        hidden = dim * 4
        self.ffn = nn.Sequential(
            nn.Linear(dim, hidden),
            nn.GELU(),
            nn.Linear(hidden, dim),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        """Apply the three residual sub-layers in order; the shape is unchanged."""
        stages = (
            (self.norm1, self.row_attn),
            (self.norm2, self.col_attn),
            (self.norm3, self.ffn),
        )
        for normalize, sublayer in stages:
            x = x + sublayer(normalize(x))
        return x

In [6]:
class QuartetTransformer(nn.Module):
    """Classify a four-taxon alignment into one of three classes.

    Tokens are embedded, a learned per-taxon embedding is added, the result is
    passed through a stack of ``MSABlock`` layers, normalised, averaged over the
    taxa and site axes, and projected to three logits.

    Args:
        seq_len: Accepted for backward compatibility; it is not read anywhere in
            this class, so the model works for any alignment length.
        dim: Embedding width D.
        heads: Attention heads per block.
        layers: Number of ``MSABlock`` layers.
    """

    def __init__(self, seq_len=512, dim=64, heads=4, layers=4):
        super().__init__()
        self.embedding = nn.Embedding(VOCAB_SIZE, dim)
        # One learned vector per row of the alignment (the table has 4 entries).
        self.taxon_embed = nn.Embedding(4, dim)

        encoder_layers = [MSABlock(dim, heads) for _ in range(layers)]
        self.blocks = nn.ModuleList(encoder_layers)

        self.norm = nn.LayerNorm(dim)
        self.classifier = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),  # (B, D, N, L) -> (B, D, 1, 1)
            nn.Flatten(),                  # -> (B, D)
            nn.Linear(dim, 3),             # -> (B, 3)
        )

    def forward(self, x):
        """Map integer tokens of shape (B, N, L) to logits of shape (B, 3)."""
        batch, n_taxa, _n_sites = x.shape
        hidden = self.embedding(x)  # (B, N, L, D)

        # Row i of every alignment gets taxon embedding i, broadcast over sites.
        row_ids = torch.arange(n_taxa, device=hidden.device)
        row_ids = row_ids.unsqueeze(0).expand(batch, n_taxa)
        hidden = hidden + self.taxon_embed(row_ids).unsqueeze(2)  # (B, N, 1, D)

        for layer in self.blocks:
            hidden = layer(hidden)
        hidden = self.norm(hidden)

        # AdaptiveAvgPool2d pools the last two axes, so move the feature axis to
        # position 1: (B, N, L, D) -> (B, D, N, L).
        channels_first = hidden.permute(0, 3, 1, 2)
        return self.classifier(channels_first)

In [7]:
import random

def random_msa(length=256):
    """Build a 4-row alignment of independent, uniformly random symbols.

    Symbols are drawn from every key of ``BASE2IDX`` except the last one ('N'),
    so the gap symbol '-' can appear.

    Args:
        length: Number of sites per row.

    Returns:
        A list of 4 lists, each holding ``length`` single-character strings.
    """
    alphabet = list(BASE2IDX)[:-1]  # drop the trailing 'N' entry
    alignment = []
    for _row in range(4):
        row = []
        for _site in range(length):
            row.append(random.choice(alphabet))
        alignment.append(row)
    return alignment

In [8]:
def encode_msa(msa_4seq):
    """Convert an alignment into a ``torch.long`` tensor of token ids.

    Args:
        msa_4seq: Iterable of equally long sequences (strings or lists of
            single characters).

    Returns:
        A tensor with one row per sequence. Characters are upper-cased before
        lookup in ``BASE2IDX``; anything not found is encoded as 'N'.
    """
    encoded_rows = []
    for seq in msa_4seq:
        row = []
        for symbol in seq:
            row.append(BASE2IDX.get(symbol.upper(), BASE2IDX['N']))
        encoded_rows.append(row)
    return torch.tensor(encoded_rows, dtype=torch.long)

In [9]:
def random_mutate(seq, rate):
    """Substitute symbols of ``seq`` at random.

    For every position one uniform number is drawn. The symbol is kept when the
    draw is greater than ``rate``; otherwise it is replaced by a uniformly
    chosen entry of ``VOCAB`` that differs from it. A symbol outside ``VOCAB``
    (such as '-' or 'N') can therefore be replaced by any nucleotide.

    Args:
        seq: Iterable of single-character symbols.
        rate: Per-position substitution probability.

    Returns:
        A new list of symbols with the same length as ``seq``.
    """
    mutated = []
    for base in seq:
        if random.random() > rate:
            mutated.append(base)
            continue
        alternatives = [candidate for candidate in VOCAB if candidate != base]
        mutated.append(random.choice(alternatives))
    return mutated

In [10]:
# Helper: mutate sequence with substitutions and deletions
def mutate(seq, sub_rate=0.1, del_rate=0.05):
    """Evolve a sequence by one step of random deletions and substitutions.

    Each non-gap site is deleted (turned into '-') with probability ``del_rate``;
    a surviving site is replaced by a different nucleotide with probability
    ``sub_rate``. Sites that are already gaps stay gaps: this function models
    only substitutions and deletions, so a position deleted in an ancestor must
    not reappear as a nucleotide when the result is mutated again.

    Args:
        seq: Iterable of single-character symbols (string or list).
        sub_rate: Per-site substitution probability for sites that survive.
        del_rate: Per-site deletion probability.

    Returns:
        A new list of single-character symbols with the same length as ``seq``.
    """
    new_seq = []
    for base in seq:
        if base == '-':
            # Already deleted upstream. Without this guard the substitution
            # branch would pick from all of "ACGT" (nothing equals '-') and
            # resurrect the site.
            new_seq.append(base)
            continue
        if random.random() < del_rate:
            # simulate a deletion by converting to a gap
            new_seq.append('-')
        elif random.random() < sub_rate:
            # substitute to a different nucleotide
            new_seq.append(random.choice([b for b in "ACGT" if b != base]))
        else:
            new_seq.append(base)
    return new_seq

In [11]:
def simulate_quartet_msa(seq_len, label):
    """Simulate a four-leaf alignment whose sister pairs are set by ``label``.

    A random root over "ACGT" is mutated twice to give two internal ancestors.
    Two leaves descend from the first ancestor and two from the second:

        label 0 -> ((A,B),(C,D))
        label 1 -> ((A,C),(B,D))
        any other value -> ((A,D),(B,C))

    Args:
        seq_len: Number of sites in the root sequence.
        label: Topology selector as listed above.

    Returns:
        A list of four strings in the fixed order A, B, C, D.
    """
    root = [random.choice("ACGT") for _ in range(seq_len)]

    # One internal ancestor per side of the central split.
    left_anc = mutate(root)
    right_anc = mutate(root)

    if label == 0:
        left_pair, right_pair = "AB", "CD"
    elif label == 1:
        left_pair, right_pair = "AC", "BD"
    else:
        left_pair, right_pair = "AD", "BC"

    # Leaves are generated left pair first, then right pair, in the listed order.
    leaves = {}
    for taxon in left_pair:
        leaves[taxon] = mutate(left_anc)
    for taxon in right_pair:
        leaves[taxon] = mutate(right_anc)

    return ["".join(leaves[taxon]) for taxon in "ABCD"]

In [12]:
def simulate_quartet_msa_gpt(seq_len, label):
    """Simulate a four-leaf alignment with lower substitution rates.

    A random ancestor over ``VOCAB`` is mutated twice with ``sub_rate=0.05`` to
    give one common ancestor per sister pair; each leaf is then derived from its
    pair's ancestor with ``sub_rate=0.01``. The deletion rate is ``mutate``'s
    default throughout. Pairing by label:

        label 0 -> (A,B) and (C,D)
        label 1 -> (A,C) and (B,D)
        any other value -> (A,D) and (B,C)

    Args:
        seq_len: Number of sites in the ancestor sequence.
        label: Topology selector as listed above.

    Returns:
        ``[A, B, C, D]``, where each entry is whatever ``mutate`` returns for
        that leaf (a list of symbols, not a joined string).
    """
    ancestor = [random.choice(VOCAB) for _ in range(seq_len)]

    if label == 0:
        first_pair, second_pair = "AB", "CD"
    elif label == 1:
        first_pair, second_pair = "AC", "BD"
    else:
        first_pair, second_pair = "AD", "BC"

    first_common = mutate(ancestor, 0.05)
    second_common = mutate(ancestor, 0.05)

    leaves = {}
    for pair, common in ((first_pair, first_common), (second_pair, second_common)):
        for taxon in pair:
            leaves[taxon] = mutate(common, 0.01)

    return [leaves[taxon] for taxon in "ABCD"]  # 4 rows of seq_len symbols

In [13]:
from torch.utils.data import Dataset
class SimulatedQuartetDataset(Dataset):
    """In-memory dataset of simulated quartet alignments with topology labels.

    All ``size`` samples are generated in the constructor via
    ``_generate_sample``; indexing returns the stored pair.

    Args:
        size: Number of samples to generate.
        seq_len: Alignment length passed to the simulator.
    """

    def __init__(self, size=1000, seq_len=256):
        self.size = size
        self.seq_len = seq_len
        samples = []
        for _ in range(size):
            samples.append(self._generate_sample())
        self.data = samples

    def __len__(self):
        return self.size

    def __getitem__(self, idx):
        """Return ``(encoded_alignment, label)`` for sample ``idx``."""
        tokens, topology = self.data[idx]
        return tokens, topology

    def _generate_sample(self):
        """Draw a label in {0, 1, 2} and simulate an alignment for it."""
        topology = random.randint(0, 2)
        alignment = simulate_quartet_msa(self.seq_len, topology)
        return encode_msa(alignment), topology

    def _generate_random_sample(self):
        """Alternative generator: random alignment paired with an unrelated random label."""
        alignment = random_msa(self.seq_len)
        topology = random.randint(0, 2)  # 3 possible topologies
        return encode_msa(alignment), topology

In [14]:
import torch
from torch.utils.data import DataLoader
import torch.nn as nn

In [15]:
def train_one_epoch(model, dataloader, optimizer, criterion, device):
    """Run one optimisation pass over ``dataloader``.

    Args:
        model: Module mapping a batch of inputs to class logits.
        dataloader: Iterable of ``(inputs, targets)`` batches exposing a
            ``dataset`` attribute whose length is the number of samples.
        optimizer: Optimizer stepping the model's parameters.
        criterion: Loss taking ``(logits, targets)`` and returning a scalar tensor.
        device: Device every batch is moved to before the forward pass.

    Returns:
        ``(mean_loss, accuracy)``, both divided by ``len(dataloader.dataset)``.
        The loss of each batch is weighted by its batch size.
    """
    model.train()
    loss_sum = 0
    n_correct = 0

    for inputs, targets in dataloader:
        inputs = inputs.to(device)
        targets = targets.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        batch_loss = criterion(outputs, targets)
        batch_loss.backward()
        optimizer.step()

        # Weight by batch size so a smaller final batch is not over-counted.
        loss_sum += batch_loss.item() * inputs.size(0)
        predicted = outputs.argmax(dim=1)
        n_correct += (predicted == targets).sum().item()

    n_samples = len(dataloader.dataset)
    return loss_sum / n_samples, n_correct / n_samples

In [16]:
# Seed both RNGs used in this notebook (`random` drives the simulator, torch
# drives weight initialisation and DataLoader shuffling) so a fresh
# Restart-and-Run-All produces the same data and the same starting weights.
SEED = 0
random.seed(SEED)
torch.manual_seed(SEED)

# Single source of truth for the alignment length; previously the model was
# built with seq_len=256 while both datasets used 512.
SEQ_LEN = 512

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = QuartetTransformer(seq_len=SEQ_LEN).to(device)

train_ds = SimulatedQuartetDataset(size=2000, seq_len=SEQ_LEN)
val_ds = SimulatedQuartetDataset(size=500, seq_len=SEQ_LEN)
train_loader = DataLoader(train_ds, batch_size=32, shuffle=True)
val_loader = DataLoader(val_ds, batch_size=32)

In [17]:
# Number of passes over the training set.
NUM_EPOCHS = 11

optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
criterion = nn.CrossEntropyLoss()

from tqdm import tqdm

for epoch in range(1, NUM_EPOCHS + 1):
    # ---- training pass ----
    model.train()
    # `seen` counts the samples processed so far. The running accuracy must be
    # divided by it: the progress bar's own counter is refreshed only at display
    # intervals and the last batch can be smaller than the others.
    total_loss, correct, seen = 0.0, 0, 0
    train_loader_tqdm = tqdm(train_loader, desc=f"Epoch {epoch} Training", leave=False)
    for x, y in train_loader_tqdm:
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()
        logits = model(x)
        loss = criterion(logits, y)
        loss.backward()
        optimizer.step()

        batch_size = x.size(0)
        total_loss += loss.item() * batch_size
        correct += (logits.argmax(dim=1) == y).sum().item()
        seen += batch_size
        train_loader_tqdm.set_postfix(loss=loss.item(), accuracy=correct / seen)

    train_loss = total_loss / len(train_loader.dataset)
    train_acc = correct / len(train_loader.dataset)

    # ---- validation pass (no gradient tracking, dropout disabled) ----
    model.eval()
    total_loss, correct, seen = 0.0, 0, 0
    val_loader_tqdm = tqdm(val_loader, desc=f"Epoch {epoch} Validation", leave=False)
    with torch.no_grad():
        for x, y in val_loader_tqdm:
            x, y = x.to(device), y.to(device)
            logits = model(x)
            loss = criterion(logits, y)

            batch_size = x.size(0)
            total_loss += loss.item() * batch_size
            correct += (logits.argmax(dim=1) == y).sum().item()
            seen += batch_size
            val_loader_tqdm.set_postfix(loss=loss.item(), accuracy=correct / seen)

    val_loss = total_loss / len(val_loader.dataset)
    val_acc = correct / len(val_loader.dataset)

    print(f"Epoch {epoch}:")
    print(f"  Train Loss: {train_loss:.4f}, Accuracy: {train_acc:.4f}")
    print(f"  Val   Loss: {val_loss:.4f}, Accuracy: {val_acc:.4f}")

# Persist only the learned parameters; the architecture must be rebuilt to load them.
torch.save(model.state_dict(), "quartet_model.pt")

                                                                                             

Epoch 1:
  Train Loss: 1.1042, Accuracy: 0.3355
  Val   Loss: 1.1071, Accuracy: 0.3720


                                                                                              

Epoch 2:
  Train Loss: 1.0997, Accuracy: 0.3495
  Val   Loss: 1.0920, Accuracy: 0.4280


                                                                                             

Epoch 3:
  Train Loss: 1.0989, Accuracy: 0.3550
  Val   Loss: 1.0994, Accuracy: 0.3160


                                                                                              

Epoch 4:
  Train Loss: 1.0958, Accuracy: 0.3640
  Val   Loss: 1.0918, Accuracy: 0.3680


                                                                                             

Epoch 5:
  Train Loss: 1.0886, Accuracy: 0.3755
  Val   Loss: 1.0736, Accuracy: 0.5820


                                                                                              

Epoch 6:
  Train Loss: 1.0562, Accuracy: 0.4950
  Val   Loss: 0.9934, Accuracy: 0.7620


                                                                                              

Epoch 7:
  Train Loss: 0.7148, Accuracy: 0.7445
  Val   Loss: 0.4245, Accuracy: 0.9360


                                                                                              

Epoch 8:
  Train Loss: 0.2994, Accuracy: 0.9790
  Val   Loss: 0.1631, Accuracy: 0.9980


                                                                                               

Epoch 9:
  Train Loss: 0.1073, Accuracy: 0.9985
  Val   Loss: 0.0358, Accuracy: 1.0000


                                                                                                

Epoch 10:
  Train Loss: 0.0295, Accuracy: 1.0000
  Val   Loss: 0.0150, Accuracy: 1.0000


                                                                                                 

Epoch 11:
  Train Loss: 0.0120, Accuracy: 1.0000
  Val   Loss: 0.0070, Accuracy: 1.0000




In [18]:
# Example MSA of 4 aligned sequences
# Rows are taxa A, B, C, D in that order. All four rows are the same 32-site
# sequence here, so replace them with real aligned sequences for a meaningful
# prediction.
msa = ["ACGT" * 8 for _taxon in "ABCD"]

# Encode to token ids and add a leading batch axis: (4, L) -> (1, 4, L).
input_tensor = encode_msa(msa)[None].to(device)

In [19]:
# Switch to eval mode explicitly so the dropout layers inside the model are
# disabled regardless of which cell last touched `model`; previously this cell
# relied on the eval() call left over from the final validation pass.
model.eval()
with torch.no_grad():
    logits = model(input_tensor)  # shape: (1, 3)
    prediction = torch.argmax(logits, dim=1).item()  # Python int: index of the largest logit

In [20]:
# Class index -> Newick-style description of the quartet topology, in the same
# order as the labels produced by the simulator.
topology_map = dict(enumerate([
    "((A,B),(C,D))",
    "((A,C),(B,D))",
    "((A,D),(B,C))",
]))

print("Predicted topology:", topology_map[prediction])

Predicted topology: ((A,D),(B,C))
