In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, random_split
from tqdm import tqdm
from scipy.optimize import curve_fit
from scipy.stats import norm
from tqdm import tqdm
import matplotlib.pyplot as plt

# -----------------------
# 1. SIMULATORS
# -----------------------
def simulate_sir(beta, gamma, T=100, noise=0.01):
    """Simulate a discrete-time SIR epidemic and return the noisy infected fraction.

    Integrates, with a unit-step forward-Euler scheme starting from
    S=0.99, I=0.01 (R=0),

        S' = -beta*S*I,   I' = beta*S*I - gamma*I,

    records I at each of the ``T`` time points (the first being the initial
    value), adds i.i.d. Gaussian noise with standard deviation ``noise`` and
    clips the result to [0, 1].

    Args:
        beta: transmission rate.
        gamma: recovery rate.
        T: number of time points returned.
        noise: standard deviation of the additive observation noise.

    Returns:
        NumPy array of ``T`` infected fractions (for ``T >= 1``).
    """
    s, i = 0.99, 0.01
    infected = [i]
    # Only S and I feed back into I, so the recovered compartment is not tracked.
    for _step in range(T - 1):
        new_infections = beta * s * i
        recoveries = gamma * i
        s, i = s - new_infections, i + (new_infections - recoveries)
        infected.append(i)
    observed = np.asarray(infected) + np.random.normal(0, noise, T)
    return observed.clip(0, 1)

def simulate_seir(beta, sigma, gamma, T=100, noise=0.01):
    """Simulate a discrete-time SEIR epidemic and return the noisy infected fraction.

    Integrates, with a unit-step forward-Euler scheme starting from
    S=0.99, E=0, I=0.01 (R=0),

        S' = -beta*S*I,  E' = beta*S*I - sigma*E,  I' = sigma*E - gamma*I,

    records I at each of the ``T`` time points (the first being the initial
    value), adds i.i.d. Gaussian noise with standard deviation ``noise`` and
    clips the result to [0, 1].

    Args:
        beta: transmission rate.
        sigma: rate at which exposed individuals become infectious.
        gamma: recovery rate.
        T: number of time points returned.
        noise: standard deviation of the additive observation noise.

    Returns:
        NumPy array of ``T`` infected fractions (for ``T >= 1``).
    """
    s, e, i = 0.99, 0.0, 0.01
    infected = [i]
    # The recovered compartment never feeds back into I, so it is not tracked.
    for _step in range(T - 1):
        exposures = beta * s * i
        onsets = sigma * e
        recoveries = gamma * i
        s, e, i = s - exposures, e + (exposures - onsets), i + (onsets - recoveries)
        infected.append(i)
    observed = np.asarray(infected) + np.random.normal(0, noise, T)
    return observed.clip(0, 1)

# -----------------------
# 2. DATASET
# -----------------------
class EpidemicDataset(Dataset):
    """Shuffled, label-balanced set of synthetic noisy SIR / SEIR infection curves.

    Each item is ``(x, y)``: ``x`` is a float32 tensor of shape ``(1, T)``
    (a single input channel for ``Conv1d``) and ``y`` is an int64 scalar
    tensor, 0 for SIR and 1 for SEIR.

    Per series, parameters are drawn uniformly: beta in [0.1, 0.5],
    gamma in [0.05, 0.2] and, for SEIR only, sigma in [0.1, 0.3]. Generation
    and shuffling use the global NumPy RNG, so seed ``np.random`` for
    reproducible datasets.

    Args:
        N: total number of series. When ``N`` is odd the extra series is an
            SIR one, so ``len(dataset) == N`` always holds.
        T: number of time points per series.
    """

    def __init__(self, N=1000, T=100):
        n_seir = N // 2
        n_sir = N - n_seir  # receives the leftover sample when N is odd
        self.data = []
        for _ in range(n_sir):
            beta, gamma = np.random.uniform(0.1, 0.5), np.random.uniform(0.05, 0.2)
            self.data.append((simulate_sir(beta, gamma, T), 0))  # label 0 = SIR
        for _ in range(n_seir):
            beta, sigma, gamma = (
                np.random.uniform(0.1, 0.5),
                np.random.uniform(0.1, 0.3),
                np.random.uniform(0.05, 0.2),
            )
            self.data.append((simulate_seir(beta, sigma, gamma, T), 1))  # label 1 = SEIR
        # Mix the two classes so any contiguous slice is not single-class.
        np.random.shuffle(self.data)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        x, y = self.data[idx]
        # unsqueeze(0) adds the channel axis: (T,) -> (1, T).
        return torch.tensor(x, dtype=torch.float32).unsqueeze(0), torch.tensor(y, dtype=torch.long)

# -----------------------
# 3. SGNN Model
# -----------------------
class ResidualBlock(nn.Module):
    """Two width-preserving Conv1d + GroupNorm layers wrapped in a skip connection.

    Computes ``gelu(x + GN(conv(gelu(GN(conv(x))))))``. The skip addition
    requires the convolutions to keep the sequence length, which holds for
    the defaults (kernel 5, padding 2); in general it needs
    ``2 * padding == kernel_size - 1``. ``channels`` must be divisible by 4,
    the GroupNorm group count.
    """

    def __init__(self, channels, kernel_size=5, padding=2):
        super().__init__()
        conv_opts = dict(kernel_size=kernel_size, padding=padding)
        # Registration order (conv1, gn1, conv2, gn2) fixes state_dict keys
        # and the order in which parameters are initialised.
        self.conv1 = nn.Conv1d(channels, channels, **conv_opts)
        self.gn1 = nn.GroupNorm(4, channels)
        self.conv2 = nn.Conv1d(channels, channels, **conv_opts)
        self.gn2 = nn.GroupNorm(4, channels)

    def forward(self, x):
        branch = F.gelu(self.gn1(self.conv1(x)))
        branch = self.gn2(self.conv2(branch))
        return F.gelu(x + branch)

class SGNNClassifier(nn.Module):
    """Residual 1-D CNN that maps an input series to class logits.

    Layout: a kernel-7 stem convolution to 64 channels, three residual stages
    (64 -> 128 -> 256 channels, joined by kernel-3 channel-expanding
    convolutions), global average- and max-pooling concatenated along the
    channel axis, then an MLP head with dropout. The pooling is adaptive, so
    the network is not tied to one sequence length.

    Args:
        T: sequence length. Not used by the network; kept so existing
            ``SGNNClassifier(T)`` calls keep working.
        in_channels: number of channels of the input series (default 1).
        num_classes: number of output logits (default 2: SIR vs SEIR).

    Input ``(batch, in_channels, length)`` -> output ``(batch, num_classes)``.
    """

    def __init__(self, T, in_channels=1, num_classes=2):
        super().__init__()

        self.stem = nn.Sequential(
            nn.Conv1d(in_channels, 64, kernel_size=7, padding=3),
            nn.GELU(),
            nn.GroupNorm(4, 64)
        )

        self.backbone = nn.Sequential(
            # Stage 1
            ResidualBlock(64),
            ResidualBlock(64),
            nn.Conv1d(64, 128, kernel_size=3, padding=1),
            nn.GELU(),

            # Stage 2
            ResidualBlock(128),
            ResidualBlock(128),
            nn.Conv1d(128, 256, kernel_size=3, padding=1),
            nn.GELU(),

            # Stage 3
            ResidualBlock(256),
            ResidualBlock(256),
        )

        # Two *parallel* global pooling branches. A ModuleList (not Sequential)
        # because they are applied side by side, never chained.
        self.pooling = nn.ModuleList([
            nn.AdaptiveAvgPool1d(1),
            nn.AdaptiveMaxPool1d(1),
        ])

        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(256 * 2, 512),  # avg + max descriptors of 256 channels each
            nn.GELU(),
            nn.Dropout(0.3),
            nn.Linear(512, 128),
            nn.GELU(),
            nn.Dropout(0.3),
            nn.Linear(128, 64),
            nn.GELU(),
            nn.Linear(64, num_classes)
        )

    def forward(self, x):
        features = self.backbone(self.stem(x))
        # Each branch yields (batch, 256, 1); concatenation gives (batch, 512, 1),
        # which the classifier's Flatten turns into (batch, 512).
        pooled = torch.cat([pool(features) for pool in self.pooling], dim=1)
        return self.classifier(pooled)



# -----------------------
# 4. AIC-based Classifier
# -----------------------
def aic_sir(ts):
    """Return the AIC of the best least-squares SIR fit to an infected-fraction series.

    The fitted model reproduces the dynamics of ``simulate_sir`` (unit-step
    Euler from S=0.99, I=0.01). ``beta`` and ``gamma`` are estimated with
    ``scipy.optimize.curve_fit`` within [0, 1]. The score is the
    least-squares AIC ``n * ln(SSE / n) + 2k`` with ``k = 2``; lower is better.

    Args:
        ts: 1-D array-like of observed infected fractions; ``ts[k]`` is the
            observation at step ``k``, with ``ts[0]`` the initial state.

    Returns:
        The AIC as a float, or ``np.inf`` if the fit fails.
    """
    ts = np.asarray(ts, dtype=float)
    steps = np.arange(len(ts))

    def sir_model(t, beta, gamma):
        S, I = 0.99, 0.01
        result = []
        for _ in t:
            # Record before stepping so result[k] is I at step k, aligned with
            # the data whose first sample is the initial I. Recording after
            # the update would shift the model one step ahead of the data.
            result.append(I)
            dS = -beta * S * I
            dI = beta * S * I - gamma * I
            S += dS
            I += dI
        return np.asarray(result)

    try:
        popt, _ = curve_fit(sir_model, steps, ts, bounds=(0, [1.0, 1.0]))
    except (RuntimeError, ValueError):
        # RuntimeError: the least-squares solver did not converge.
        # ValueError: empty/non-finite data or non-finite model residuals.
        # Deliberately narrow: a bare except would also swallow KeyboardInterrupt.
        return np.inf

    sse = np.sum((ts - sir_model(steps, *popt)) ** 2)
    k = 2  # fitted parameters: beta, gamma
    return len(ts) * np.log(sse / len(ts)) + 2 * k

def aic_seir(ts):
    """Return the AIC of the best least-squares SEIR fit to an infected-fraction series.

    The fitted model reproduces the dynamics of ``simulate_seir`` (unit-step
    Euler from S=0.99, E=0, I=0.01). ``beta``, ``sigma`` and ``gamma`` are
    estimated with ``scipy.optimize.curve_fit`` within [0, 1]. The score is
    the least-squares AIC ``n * ln(SSE / n) + 2k`` with ``k = 3``; lower is
    better.

    Args:
        ts: 1-D array-like of observed infected fractions; ``ts[k]`` is the
            observation at step ``k``, with ``ts[0]`` the initial state.

    Returns:
        The AIC as a float, or ``np.inf`` if the fit fails.
    """
    ts = np.asarray(ts, dtype=float)
    steps = np.arange(len(ts))

    def seir_model(t, beta, sigma, gamma):
        S, E, I = 0.99, 0.0, 0.01
        result = []
        for _ in t:
            # Record before stepping so result[k] is I at step k, aligned with
            # the data whose first sample is the initial I. Recording after
            # the update would shift the model one step ahead of the data.
            result.append(I)
            dS = -beta * S * I
            dE = beta * S * I - sigma * E
            dI = sigma * E - gamma * I
            S += dS
            E += dE
            I += dI
        return np.asarray(result)

    try:
        popt, _ = curve_fit(seir_model, steps, ts, bounds=(0, [1.0, 1.0, 1.0]))
    except (RuntimeError, ValueError):
        # RuntimeError: the least-squares solver did not converge.
        # ValueError: empty/non-finite data or non-finite model residuals.
        # Deliberately narrow: a bare except would also swallow KeyboardInterrupt.
        return np.inf

    sse = np.sum((ts - seir_model(steps, *popt)) ** 2)
    k = 3  # fitted parameters: beta, sigma, gamma
    return len(ts) * np.log(sse / len(ts)) + 2 * k

def aic_predict(ts):
    """Classify a series as SIR (0) or SEIR (1) by comparing model AICs.

    SIR wins only when its AIC is strictly lower. Ties, including both fits
    failing with ``np.inf``, resolve to SEIR.
    """
    sir_score = aic_sir(ts)
    seir_score = aic_seir(ts)
    if sir_score < seir_score:
        return 0
    return 1

# -----------------------
# 5. Training Loop
# -----------------------
def _aic_accuracy(samples):
    """Return the AIC baseline's accuracy over an iterable of ``(x, y)`` samples."""
    correct = 0
    total = 0
    for x, y in tqdm(samples):
        # The AIC fit is inherently per-series; x is (1, T) -> squeeze to (T,).
        correct += int(aic_predict(x.squeeze().numpy()) == y.item())
        total += 1
    return correct / total


def _sgnn_accuracy(model, loader, device):
    """Return ``model``'s top-1 accuracy over a batched ``loader`` (eval mode, no grad)."""
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for x, y in loader:
            # argmax over the class axis, so any batch size is handled correctly.
            preds = model(x.to(device)).argmax(dim=1).cpu()
            correct += (preds == y).sum().item()
            total += y.numel()
    return correct / total


def _plot_errors(sgnn_errors, aic_error, out_path):
    """Save an SGNN error-vs-epoch plot with the AIC error as a horizontal reference."""
    style = {
        "font.family": "serif",
        "font.size": 14,
        "figure.dpi": 300,
        "text.usetex": False,
    }
    # rc_context scopes the style to this figure instead of mutating global rcParams.
    with plt.rc_context(style):
        fig, ax = plt.subplots(figsize=(7, 5))
        ax.plot(range(1, len(sgnn_errors) + 1), sgnn_errors, label='SGNN Classification Error', marker='o', linewidth=2)
        ax.axhline(y=aic_error, color='black', linestyle='--', label='AIC Error', linewidth=2)

        ax.set_xlabel('Epoch', fontsize=14)
        ax.set_ylabel('Classification Error', fontsize=14)
        ax.set_title('SGNN vs AIC Error Over Epochs', fontsize=15)
        ax.legend(fontsize=13)
        ax.tick_params(axis='both', which='major', labelsize=12)

        fig.tight_layout()
        fig.savefig(out_path, bbox_inches='tight')
        plt.close(fig)


def train(*, T=100, n_samples=60000, epochs=20, lr=1e-5, batch_size=1536,
          eval_batch_size=1024, seed=None, device=None,
          out_path="sgnn_vs_aic_error.pdf"):
    """Train the SGNN classifier and compare it with the AIC model-selection baseline.

    Steps:
      1. Generate ``n_samples`` synthetic series and split them 80/20 into
         train/validation.
      2. Score the AIC baseline once on the validation split.
      3. Train the SGNN with Adam + cross-entropy for ``epochs`` epochs,
         reporting validation error after each epoch.
      4. Save an error-vs-epoch plot to ``out_path``.

    Args:
        T: time points per series.
        n_samples: total number of generated series.
        epochs: number of training epochs.
        lr: Adam learning rate.
        batch_size: training mini-batch size.
        eval_batch_size: batch size for SGNN validation.
        seed: if given, seeds NumPy (data generation) and PyTorch (split,
            shuffling, weight init) for reproducible runs.
        device: torch device; defaults to CUDA when available, else CPU.
        out_path: destination of the plot (format inferred from the extension).

    Returns:
        ``(sgnn_errors, aic_error)``: the per-epoch SGNN validation error list
        and the AIC baseline's validation error.
    """
    if seed is not None:
        np.random.seed(seed)
        torch.manual_seed(seed)
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    dataset = EpidemicDataset(N=n_samples, T=T)
    train_size = int(0.8 * len(dataset))
    val_size = len(dataset) - train_size
    train_data, val_data = random_split(dataset, [train_size, val_size])

    train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_data, batch_size=eval_batch_size)

    model = SGNNClassifier(T).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()

    # Compute AIC performance on validation set once
    print("Computing AIC performance...")
    aic_acc = _aic_accuracy(val_data)
    aic_error = 1 - aic_acc
    print(f"AIC Accuracy = {aic_acc:.3f} | Error = {aic_error:.3f}")

    sgnn_errors = []
    for epoch in range(1, epochs + 1):
        model.train()
        total_loss = 0.0
        for x, y in tqdm(train_loader):
            x, y = x.to(device), y.to(device)
            optimizer.zero_grad()
            loss = criterion(model(x), y)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()

        sgnn_acc = _sgnn_accuracy(model, val_loader, device)
        sgnn_error = 1 - sgnn_acc
        sgnn_errors.append(sgnn_error)

        # Loss is the mean of per-batch mean losses.
        print(f"Epoch {epoch}: SGNN Acc = {sgnn_acc:.3f}, Error = {sgnn_error:.3f}, Loss = {total_loss/len(train_loader):.4f}")

    _plot_errors(sgnn_errors, aic_error, out_path)
    return sgnn_errors, aic_error



if __name__ == "__main__":
    # Run the full experiment only when executed as a script, not on import.
    # The experiment covers data generation, the AIC baseline, SGNN training
    # and the error plot.
    train()

Computing AIC performance...


100%|██████████| 12000/12000 [05:59<00:00, 33.42it/s]


AIC Accuracy = 0.908 | Error = 0.092


100%|██████████| 32/32 [02:55<00:00,  5.49s/it]


Epoch 1: SGNN Acc = 0.674, Error = 0.326, Loss = 0.6558


100%|██████████| 32/32 [02:55<00:00,  5.49s/it]


Epoch 2: SGNN Acc = 0.738, Error = 0.262, Loss = 0.5398


100%|██████████| 32/32 [02:55<00:00,  5.48s/it]


Epoch 3: SGNN Acc = 0.805, Error = 0.195, Loss = 0.4655


100%|██████████| 32/32 [02:55<00:00,  5.50s/it]


Epoch 4: SGNN Acc = 0.866, Error = 0.134, Loss = 0.4043


100%|██████████| 32/32 [02:56<00:00,  5.51s/it]


Epoch 5: SGNN Acc = 0.934, Error = 0.066, Loss = 0.2955


100%|██████████| 32/32 [02:56<00:00,  5.51s/it]


Epoch 6: SGNN Acc = 0.943, Error = 0.057, Loss = 0.1936


100%|██████████| 32/32 [02:56<00:00,  5.50s/it]


Epoch 7: SGNN Acc = 0.950, Error = 0.050, Loss = 0.1555


100%|██████████| 32/32 [02:56<00:00,  5.51s/it]


Epoch 8: SGNN Acc = 0.944, Error = 0.056, Loss = 0.1339


100%|██████████| 32/32 [02:56<00:00,  5.50s/it]


Epoch 9: SGNN Acc = 0.951, Error = 0.049, Loss = 0.1276


100%|██████████| 32/32 [02:55<00:00,  5.49s/it]


Epoch 10: SGNN Acc = 0.953, Error = 0.047, Loss = 0.1187


100%|██████████| 32/32 [02:55<00:00,  5.49s/it]


Epoch 11: SGNN Acc = 0.957, Error = 0.043, Loss = 0.1132


100%|██████████| 32/32 [02:54<00:00,  5.46s/it]


Epoch 12: SGNN Acc = 0.932, Error = 0.068, Loss = 0.1111


100%|██████████| 32/32 [02:54<00:00,  5.47s/it]


Epoch 13: SGNN Acc = 0.959, Error = 0.041, Loss = 0.1157


100%|██████████| 32/32 [02:55<00:00,  5.47s/it]


Epoch 14: SGNN Acc = 0.953, Error = 0.047, Loss = 0.1039


100%|██████████| 32/32 [02:55<00:00,  5.47s/it]


Epoch 15: SGNN Acc = 0.941, Error = 0.059, Loss = 0.1082


100%|██████████| 32/32 [02:55<00:00,  5.50s/it]


Epoch 16: SGNN Acc = 0.959, Error = 0.041, Loss = 0.1103


100%|██████████| 32/32 [02:56<00:00,  5.52s/it]


Epoch 17: SGNN Acc = 0.959, Error = 0.041, Loss = 0.0970


100%|██████████| 32/32 [02:56<00:00,  5.52s/it]


Epoch 18: SGNN Acc = 0.955, Error = 0.045, Loss = 0.0964


100%|██████████| 32/32 [02:55<00:00,  5.50s/it]


Epoch 19: SGNN Acc = 0.944, Error = 0.056, Loss = 0.0951


100%|██████████| 32/32 [02:56<00:00,  5.51s/it]


Epoch 20: SGNN Acc = 0.958, Error = 0.042, Loss = 0.0954
