In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import numpy as np
import random
import os

# Set device to CUDA if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

# ------------------------------
# Vocabulary (index 0 reserved for <blank> / EOS)
VOCAB = [
    "<blank>", "help", "food", "water", "pain", "yes", "no", "bathroom",
    "sleep", "happy", "sad", "love", "more", "stop", "go", "rest",
    "hurt", "cold", "warm", "I", "you", "like", "good", "bad", "now", "need", "am",
    "the", "my", "it", "this", "that", "what", "when", "where", "why", "how",
    "me", "feel", "think", "know", "want", "have", "can", "will", "would", "could",
    "please", "thank", "you're", "welcome", "sorry", "excuse", "me", "a", "an", "and", "or", "but", "because",
    "today", "tomorrow", "yesterday", "morning", "afternoon", "evening", "night",
    "hungry", "thirsty", "tired", "sick", "better", "worse", "home", "work", "school",
    "friend", "family", "doctor", "medicine", "book", "music", "movie", "game",
    "one", "two", "three", "four", "five", "six", "seven", "eight", "nine", "ten",
    "hot", "cold", "big", "small", "fast", "slow", "loud", "quiet",
    "up", "down", "in", "out", "on", "off", "here", "there",
    ".", ",", "?", "!",  # Punctuation
    "hello", "goodbye",  # Greetings
    "computer", "brain", "eeg", "bci", "technology", # BCI related
]
VOCAB_SIZE = len(VOCAB)
EOS_TOKEN = 0  # End-of-sequence token

# ------------------------------
# Expanded Coherent Training Fragments (Examples - Add more!)
coherent_fragments = [
    [19, 25, 3],        # "I need water"
    [19, 25, 2],        # "I need food"
    [19, 26, 10],       # "I am sad"
    [19, 26, 16],       # "I am hurt"
    [19, 25, 1],        # "I need help"
    [19, 26, 17],       # "I am cold"
    [19, 26, 18],       # "I am warm"
    [19, 25, 7],        # "I need bathroom"
    [19, 26, 9],        # "I am happy"
    [19, 26, 22],       # "I am good"
    [19, 11, 2],        # "I love food"
    [19, 21, 2],        # "I like food"
    [19, 25, 12, 2],    # "I need more food"
    [19, 25, 12, 3],    # "I need more water"
    [19, 11, 3],        # "I love water"
    [13, 14, 22],       # "stop go good"
    [19, 21, 20],       # "I like you"
    [20, 25, 2],        # "you need food"
    [20, 25, 3],        # "you need water"
    [19, 25, 11],       # "I need love"
    [19, 25, 2, 3],     # "I need food water"
    [19, 26, 16, 17],   # "I am hurt cold"
    [19, 25, 11, 2],    # "I need love food"
    [19, 21, 22, 3],    # "I like good water"
    [20, 25, 1, 24],    # "you need help now"
    [19, 25, 27],       # I need my medicine
    [19, 25, 30],       # I need my book
    [19, 25, 31],       # I need my music
    [19, 25, 32],       # I need my movie
    [19, 25, 33],       # I need my game
    [19, 21, 28],       # I like my friend
    [19, 21, 29],       # I like my family
    [19, 21, 34],       # I like my doctor
    [19, 25, 35],       # I need a doctor
    [19, 26, 40],       # I am hungry
    [19, 26, 41],       # I am thirsty
    [19, 26, 42],       # I am tired
    [19, 26, 43],       # I am sick
    [19, 26, 44],       # I feel better
    [19, 26, 45],       # I feel worse
    [19, 26, 46],       # I am home
    [19, 26, 47],       # I am at work
    [19, 26, 48],       # I am at school
    [19, 25, 50],       # I need help now
    [19, 25, 51],       # I need to rest now
    [19, 25, 52],       # I need to sleep now
    [19, 25, 53],       # I need to go to the bathroom now
    [19, 25, 54],       # I need to eat now
    [19, 25, 55],       # I need to drink now
    [19, 25, 56],       # I need to take my medicine now
    [19, 25, 57],       # I need to see a doctor now
    [19, 25, 58],       # I need my family
    [19, 25, 59],       # I need my friend
    [19, 25, 60],       # I need my computer
    [19, 25, 61],       # I need my brain
    [19, 25, 62],       # I need my eeg
    [19, 25, 63],       # I need my bci
    [19, 25, 64],       # I need my technology
    [19, 21, 65],       # I like my computer
    [19, 21, 66],       # I like my brain
    [19, 21, 67],       # I like my eeg
    [19, 21, 68],       # I like my bci
    [19, 21, 69],       # I like my technology
    [19, 26, 70],       # I am a computer
    [19, 26, 71],       # I am a brain
    [19, 26, 72],       # I am a eeg
    [19, 26, 73],       # I am a bci
    [19, 26, 74],       # I am a technology
    [19, 25, 75],       # I need help with my computer
    [19, 25, 76],       # I need help with my brain
    [19, 25, 77],       # I need help with my eeg
    [19, 25, 78],       # I need help with my bci
    [19, 25, 79],       # I need help with my technology
    [19, 21, 80]        # I like to use my computer
]
# ------------------------------
# 5 New Target Sentences (for testing)
new_target_sentences = [
    [19, 25, 2, 3],        # "I need food water"
    [19, 26, 16, 17],      # "I am hurt cold"
    [19, 25, 11, 2],       # "I need love food"
    [19, 21, 22, 3],       # "I like good water"
    [20, 25, 1, 24],       # "you need help now"
    [19, 25, 27],         # "I need my medicine"
    [19, 25, 30],         # "I need my book"
    [19, 21, 28],         # "I like my friend"
    [19, 26, 40],         # "I am hungry"
    [19, 26, 42],         # "I am tired"
    [19, 25, 50],         # "I need help now"
    [19, 25, 52],         # "I need to sleep now"
    [19, 25, 54],         # "I need to eat now"
    [19, 25, 55],         # "I need to drink now"
    [19, 25, 56],         # "I need to take my medicine now"
    [19, 25, 58],         # "I need my family"
    [19, 25, 60],         # "I need my computer"
    [19, 21, 65],         # "I like my computer"
    [19, 26, 70],         # "I am a computer"
    [19, 25, 75],         # "I need help with my computer"
    [19, 25, 80],         # "I want to use my computer"
    [19, 25, 90],         # "I want to learn more about computers"
    [19, 25, 100],        # "I want to control my computer with my brain"
    [19, 25, 108],        # "I want to type with my brain"
]

# ------------------------------
# EEGNet-inspired feature extractor (output sequence length = 32)
class EEGNetFeatureExtractor(nn.Module):
    def __init__(self, Chans=16, Samples=256, dropoutRate=0.2, kernLength=64, F1=8, D=2, F2=16):
        super(EEGNetFeatureExtractor, self).__init__()
        self.firstconv = nn.Sequential(
            nn.Conv2d(1, F1, kernel_size=(1, kernLength), padding=(0, kernLength//2), bias=False),
            nn.BatchNorm2d(F1, affine=False)
        )
        self.depthwiseConv = nn.Sequential(
            nn.Conv2d(F1, F1*D, kernel_size=(Chans, 1), groups=F1, bias=False),
            nn.BatchNorm2d(F1*D, affine=False),
            nn.ELU(),
            nn.AvgPool2d(kernel_size=(1, 4)),  # 256 -> 64
            nn.Dropout(dropoutRate)
        )
        self.separableConv = nn.Sequential(
            nn.Conv2d(F1*D, F2, kernel_size=(1, 8), padding=(0, 8//2), bias=False),
            nn.BatchNorm2d(F2, affine=False),
            nn.ELU(),
            nn.AvgPool2d(kernel_size=(1, 2)),  # 64 -> 32
            nn.Dropout(dropoutRate)
        )
    def forward(self, x):
        x = x.unsqueeze(1)  # (batch, 1, Chans, Samples)
        x = self.firstconv(x)
        x = self.depthwiseConv(x)
        x = self.separableConv(x)  # (batch, F2, 1, 32)
        x = x.squeeze(2)          # (batch, F2, 32)
        x = x.permute(0, 2, 1)      # (batch, 32, F2)
        return x

# ------------------------------
# Sequence model: EEG -> LSTM -> Classifier
class EEGToTextModel(nn.Module):
    def __init__(self, hidden_dim=128, vocab_size=VOCAB_SIZE, num_layers=2):
        super(EEGToTextModel, self).__init__()
        self.feature_extractor = EEGNetFeatureExtractor(Chans=16, Samples=256)
        self.lstm = nn.LSTM(input_size=16, hidden_size=hidden_dim,
                            num_layers=num_layers, batch_first=True, bidirectional=True)
        self.classifier = nn.Linear(hidden_dim*2, vocab_size)
    def forward(self, eeg_input):
        features = self.feature_extractor(eeg_input)  # (batch, 32, 16)
        lstm_out, _ = self.lstm(features)              # (batch, 32, hidden_dim*2)
        logits = self.classifier(lstm_out)             # (batch, 32, vocab_size)
        return logits

# ------------------------------
# Modified epsilon-greedy decoder with explicit max_len.
def epsilon_greedy_decode(logits, epsilon=0.0, target_seq=None, max_len=None, repetition_limit=2):
    # logits: (1, seq_len, vocab_size)
    probs = torch.softmax(logits, dim=-1)
    seq_len = probs.size(1)
    if max_len is None:
        if target_seq is not None:
            max_len = len(target_seq)
        else:
            max_len = seq_len
    decoded = []
    last_token = None
    repeat_count = 0
    for t in range(seq_len):
        if random.random() < epsilon:
            token = random.randint(1, VOCAB_SIZE - 1)
        else:
            token = torch.argmax(probs[0, t]).item()
        if token == EOS_TOKEN:
            break
        if token == last_token:
            repeat_count += 1
            if repeat_count >= repetition_limit:
                break
        else:
            repeat_count = 1
        decoded.append(token)
        last_token = token
        if len(decoded) >= max_len:
            break
    return decoded

# ------------------------------
# EEG Dataset using coherent fragments.
class EEGDataset(Dataset):
    def __init__(self, num_samples=200):
        self.num_samples = num_samples
        self.channels = 16
        self.time_samples = 256
        self.data = []
        self.labels = []
        np.random.seed(42)
        # Fixed EEG pattern for each token.
        self.token_patterns = {
            token: torch.tensor(np.random.randn(self.channels, self.time_samples//4), dtype=torch.float).to(device)
            for token in range(1, VOCAB_SIZE)
        }
        for _ in range(num_samples):
            label_seq = random.choice(coherent_fragments)
            self.labels.append(torch.tensor(label_seq, dtype=torch.long).to(device))
            eeg = torch.randn(self.channels, self.time_samples, device=device) * 0.1
            seq_length = len(label_seq)
            seg_len = self.time_samples // seq_length
            for i, token in enumerate(label_seq):
                start = i * seg_len
                end = start + seg_len
                pattern = self.token_patterns[token]
                if pattern.shape[1] != seg_len:
                    pattern = F.interpolate(pattern.unsqueeze(0), size=seg_len, mode='linear', align_corners=False).squeeze(0)
                eeg[:, start:end] += pattern
            self.data.append(eeg)
    def __len__(self):
        return self.num_samples
    def __getitem__(self, idx):
        return self.data[idx], self.labels[idx]

def collate_fn(batch):
    eegs, label_seqs = zip(*batch)
    eegs = torch.stack(eegs)
    return eegs, label_seqs

# ------------------------------
# Function to synthesize EEG for testing new sentences.
def synthesize_eeg(label_seq, token_patterns, channels=16, time_samples=256):
    eeg = torch.randn(channels, time_samples, device=device) * 0.1
    seq_length = len(label_seq)
    seg_len = time_samples // seq_length
    for i, token in enumerate(label_seq):
        start = i * seg_len
        end = start + seg_len
        pattern = token_patterns[token]
        if pattern.shape[1] != seg_len:
            pattern = F.interpolate(pattern.unsqueeze(0), size=seg_len, mode='linear', align_corners=False).squeeze(0)
        eeg[:, start:end] += pattern
    return eeg

# ------------------------------
# Training loop with online update and full logging.
def train_model(model, dataloader, optimizer, ce_loss_fn, num_epochs=50, teacher_forcing_prob=0.8, scheduled_decay=0.98, epsilon=0.3, online_lr=1e-4):
    total_samples = len(dataloader.dataset)
    for epoch in range(num_epochs):
        model.train()
        total_loss = 0.0
        total_tokens = 0
        correct_count = 0
        for eegs, label_seqs in dataloader:
            eegs = eegs.to(device)
            optimizer.zero_grad()
            logits = model(eegs)  # (batch, 32, vocab_size)
            batch_targets = []
            for seq in label_seqs:
                seq_list = seq.tolist()
                padded = seq_list + [EOS_TOKEN]*(32 - len(seq_list))
                batch_targets.append(padded)
            targets = torch.tensor(batch_targets, dtype=torch.long, device=device)
            loss = ce_loss_fn(logits.view(-1, VOCAB_SIZE), targets.view(-1))
            loss.backward()
            optimizer.step()

            total_loss += loss.item() * logits.size(0) * logits.size(1)
            total_tokens += logits.size(0) * logits.size(1)

            # For each sample in the batch, check prediction and if wrong, perform an online update.
            for i, true_seq in enumerate(label_seqs):
                true_list = true_seq.tolist()
                decoded = epsilon_greedy_decode(logits[i:i+1], epsilon=epsilon, target_seq=true_list, max_len=len(true_list))
                if decoded == true_list:
                    correct_count += 1
                else:
                    optimizer.zero_grad()
                    sample_logits = model(eegs[i:i+1])
                    sample_target = torch.tensor([true_list + [EOS_TOKEN]*(32 - len(true_list))], dtype=torch.long, device=device)
                    sample_loss = ce_loss_fn(sample_logits.view(-1, VOCAB_SIZE), sample_target.view(-1))
                    orig_lr = optimizer.param_groups[0]['lr']
                    for group in optimizer.param_groups:
                        group['lr'] = online_lr
                    sample_loss.backward()
                    optimizer.step()
                    for group in optimizer.param_groups:
                        group['lr'] = orig_lr

        avg_loss = total_loss / total_tokens
        exact_match_acc = correct_count / total_samples * 100.0
        print(f"Epoch {epoch+1}: Avg Loss = {avg_loss:.4f}, Exact Match Accuracy = {exact_match_acc:.2f}%")
        
        # Full sample predictions logging.
        model.eval()
        with torch.no_grad():
            print(f"Epoch {epoch+1} full sample predictions:")
            for idx in range(total_samples):
                eeg_sample, label_seq = dataset[idx]
                eeg_sample = eeg_sample.unsqueeze(0).to(device)
                sample_logits = model(eeg_sample)
                pred_tokens = epsilon_greedy_decode(sample_logits, epsilon=0.0, target_seq=label_seq.tolist(), max_len=len(label_seq))
                pred_sentence = " ".join([VOCAB[token] for token in pred_tokens])
                true_sentence = " ".join([VOCAB[token] for token in label_seq.tolist()])
                print(f"  Sample {idx+1}: Prediction: [{pred_sentence}] | True: [{true_sentence}]")
        model.train()
        teacher_forcing_prob *= scheduled_decay
    print("Training finished.")

# ------------------------------
# Dynamic checkpoint loader to handle vocabulary size changes.
def load_checkpoint_dynamic(model, optimizer, filename="model_checkpoint.pth"):
    if os.path.exists(filename):
        print("Loading checkpoint...")
        checkpoint = torch.load(filename, map_location=device, weights_only=True)
        model_state = checkpoint['model_state_dict']
        mismatch = False
        if model_state['classifier.weight'].shape[0] != model.classifier.weight.shape[0]:
            print("Vocabulary size mismatch detected. Reinitializing classifier layer...")
            mismatch = True
            model_state.pop('classifier.weight')
            model_state.pop('classifier.bias')
        model.load_state_dict(model_state, strict=False)
        if not mismatch:
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        else:
            print("Optimizer state not loaded due to classifier mismatch.")
        start_epoch = checkpoint['epoch'] + 1
        print(f"Resuming training from epoch {start_epoch}.")
        return start_epoch
    return 0

# ------------------------------
# Save function for persistence.
def save_checkpoint(model, optimizer, epoch, filename="model_checkpoint.pth"):
    checkpoint = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict()
    }
    torch.save(checkpoint, filename)
    print(f"Checkpoint saved at epoch {epoch}.")

# ------------------------------
# Main entry point.
def main():
    global dataset
    dataset = EEGDataset(num_samples=200)
    dataloader = DataLoader(dataset, batch_size=8, shuffle=True, collate_fn=collate_fn)
    # Instantiate the model with the dynamic vocabulary size.
    model = EEGToTextModel(hidden_dim=128, vocab_size=len(VOCAB), num_layers=2).to(device)
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    ce_loss_fn = nn.CrossEntropyLoss(ignore_index=EOS_TOKEN)

    checkpoint_file = "model_checkpoint.pth"
    # Use the dynamic checkpoint loader.
    start_epoch = load_checkpoint_dynamic(model, optimizer, filename=checkpoint_file)

    # Train for total_epochs (or resume from checkpoint)
    total_epochs = 380
    for epoch in range(start_epoch, total_epochs):
        train_model(model, dataloader, optimizer, ce_loss_fn, num_epochs=1,
                    teacher_forcing_prob=0.8, scheduled_decay=0.98, epsilon=0.3, online_lr=1e-4)
        save_checkpoint(model, optimizer, epoch, filename=checkpoint_file)

    print("\nTest Predictions on New Target Sentences:")
    model.eval()
    with torch.no_grad():
        for target in new_target_sentences:
            test_eeg = synthesize_eeg(target, dataset.token_patterns, channels=16, time_samples=256)
            test_eeg = test_eeg.unsqueeze(0)
            test_logits = model(test_eeg)
            pred_tokens = epsilon_greedy_decode(test_logits, epsilon=0.0, target_seq=target, max_len=len(target))
            pred_sentence = " ".join([VOCAB[token] for token in pred_tokens])
            true_sentence = " ".join([VOCAB[token] for token in target])
            print(f"Prediction: [{pred_sentence}] | True: [{true_sentence}]")

if __name__ == "__main__":
    main()


Using device: cuda
Loading checkpoint...
Resuming training from epoch 375.
Epoch 1: Avg Loss = 0.1815, Exact Match Accuracy = 24.00%
Epoch 1 full sample predictions:
  Sample 1: Prediction: [I am work] | True: [I am work]
  Sample 2: Prediction: [I am good] | True: [I am good]
  Sample 3: Prediction: [I need and] | True: [I need and]
  Sample 4: Prediction: [I am would] | True: [I am would]
  Sample 5: Prediction: [I need love food] | True: [I need love food]
  Sample 6: Prediction: [I am work] | True: [I am work]
  Sample 7: Prediction: [I need a] | True: [I need a]
  Sample 8: Prediction: [I am cold] | True: [I am cold]
  Sample 9: Prediction: [I like my] | True: [I like my]
  Sample 10: Prediction: [I like thirsty] | True: [I like thirsty]
  Sample 11: Prediction: [I need morning] | True: [I need morning]
  Sample 12: Prediction: [I like my] | True: [I like my]
  Sample 13: Prediction: [I am worse] | True: [I am worse]
  Sample 14: Prediction: [I am better] | True: [I am better]
  S