# In [None]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
from transformers import (
    BertModel, BertTokenizer,
    RobertaModel, RobertaTokenizer,
    ElectraModel, ElectraTokenizer
)
from peft import LoraConfig, get_peft_model, TaskType
from river import drift
import numpy as np
from collections import deque
from typing import List, Tuple, Dict
import warnings
warnings.filterwarnings('ignore')


class LoRATransformerWrapper:
    """Transformer encoder with LoRA adapters and a linear classification head.

    Loads a BERT, RoBERTa or ELECTRA checkpoint (chosen by a substring of the
    checkpoint name), injects LoRA adapters into the attention ``query`` and
    ``value`` projections via ``peft``, and adds an ``nn.Linear`` head that is
    applied to the hidden state at sequence position 0 ([CLS] token).
    """

    def __init__(self, model_name: str, num_labels: int = 2, rank: int = 8, alpha: int = 16):
        """Load tokenizer and encoder, apply LoRA and build the classifier head.

        Args:
            model_name: Hugging Face checkpoint name. It must contain
                ``'bert-base'``, ``'roberta'`` or ``'electra'``.
            num_labels: Number of output classes of the classifier head.
            rank: LoRA rank (``r``).
            alpha: LoRA scaling factor (``lora_alpha``).

        Raises:
            ValueError: If ``model_name`` matches none of the supported families.
        """
        self.model_name = model_name
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # Select tokenizer/encoder classes by checkpoint name; first match wins.
        if 'bert-base' in model_name:
            self.tokenizer = BertTokenizer.from_pretrained(model_name)
            self.base_model = BertModel.from_pretrained(model_name)
        elif 'roberta' in model_name:
            self.tokenizer = RobertaTokenizer.from_pretrained(model_name)
            self.base_model = RobertaModel.from_pretrained(model_name)
        elif 'electra' in model_name:
            self.tokenizer = ElectraTokenizer.from_pretrained(model_name)
            self.base_model = ElectraModel.from_pretrained(model_name)
        else:
            # Fail fast instead of hitting an obscure AttributeError on
            # self.base_model a few lines below.
            raise ValueError(
                f"Unsupported model '{model_name}': expected a BERT ('bert-base'), "
                "RoBERTa ('roberta') or ELECTRA ('electra') checkpoint"
            )

        # LoRA configuration: adapters only on the attention Q and V projections.
        lora_config = LoraConfig(
            task_type=TaskType.FEATURE_EXTRACTION,
            r=rank,
            lora_alpha=alpha,
            lora_dropout=0.1,
            target_modules=["query", "value"],
        )

        # Wrap the encoder with LoRA adapters.
        self.model = get_peft_model(self.base_model, lora_config)
        self.model.to(self.device)

        # Classification head on top of the position-0 embedding.
        hidden_size = self.base_model.config.hidden_size
        self.classifier = nn.Linear(hidden_size, num_labels).to(self.device)

        # print_trainable_parameters() prints its own report and returns None,
        # so count the trainable parameters explicitly for this message.
        n_trainable = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
        print(f"✓ {model_name} carregado com LoRA (params treináveis: {n_trainable:,})")

    def encode(self, texts: List[str]) -> torch.Tensor:
        """Return position-0 ([CLS]) embeddings for ``texts`` without gradients.

        Puts the encoder in eval mode. Inputs are padded/truncated to at most
        128 tokens. The result has one row per text and lives on ``self.device``.
        """
        self.model.eval()

        encoded = self.tokenizer(
            texts,
            padding=True,
            truncation=True,
            max_length=128,
            return_tensors='pt'
        ).to(self.device)

        with torch.no_grad():
            outputs = self.model(**encoded)
            # Hidden state of the first token ([CLS]) as the sentence embedding.
            embeddings = outputs.last_hidden_state[:, 0, :]

        return embeddings

    def fine_tune(self, buffer: List[Tuple[str, int]], epochs: int = 3, lr: float = 3e-4):
        """Fine-tune the LoRA adapters and the classifier head on ``buffer``.

        Args:
            buffer: ``(text, label)`` pairs; an empty buffer is a no-op.
            epochs: Number of passes over ``buffer``.
            lr: AdamW learning rate. A fresh optimizer is created on every call.
        """
        if not buffer:
            # Nothing to train on (and avoids dividing by len(dataloader) == 0).
            return

        self.model.train()
        self.classifier.train()

        texts = [text for text, _ in buffer]
        labels = torch.tensor([label for _, label in buffer], dtype=torch.long).to(self.device)

        encoded = self.tokenizer(
            texts,
            padding=True,
            truncation=True,
            max_length=128,
            return_tensors='pt'
        ).to(self.device)

        dataset = TensorDataset(
            encoded['input_ids'],
            encoded['attention_mask'],
            labels
        )
        dataloader = DataLoader(dataset, batch_size=16, shuffle=True)

        # Optimize only the parameters peft left trainable (the LoRA adapters)
        # plus the head; the frozen base weights never receive gradients.
        trainable = [p for p in self.model.parameters() if p.requires_grad]
        optimizer = torch.optim.AdamW(
            trainable + list(self.classifier.parameters()),
            lr=lr
        )
        criterion = nn.CrossEntropyLoss()

        for epoch in range(epochs):
            total_loss = 0
            for input_ids, attention_mask, batch_labels in dataloader:
                optimizer.zero_grad()

                outputs = self.model(input_ids=input_ids, attention_mask=attention_mask)
                embeddings = outputs.last_hidden_state[:, 0, :]
                logits = self.classifier(embeddings)

                loss = criterion(logits, batch_labels)
                loss.backward()
                optimizer.step()

                total_loss += loss.item()

            avg_loss = total_loss / len(dataloader)
            print(f"  Epoch {epoch+1}/{epochs} - Loss: {avg_loss:.4f}")


class MLPMetaLearner(nn.Module):
    """MLP meta-learner that classifies the concatenated base-model embeddings.

    Layers: Linear(input_dim, hidden_dim) -> ReLU -> Dropout(0.3)
    -> Linear(hidden_dim, hidden_dim // 2) -> ReLU -> Dropout(0.3)
    -> Linear(hidden_dim // 2, num_labels). The module moves itself to CUDA
    when available, otherwise CPU (``self.device``).
    """

    def __init__(self, input_dim: int = 768*3, hidden_dim: int = 256, num_labels: int = 2):
        super().__init__()
        self.network = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(hidden_dim // 2, num_labels)
        )
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.to(self.device)

    def forward(self, x):
        """Return raw (unnormalized) class logits for input ``x``."""
        return self.network(x)

    def train_mlp(self, embeddings: torch.Tensor, labels: torch.Tensor,
                  epochs: int = 5, lr: float = 1e-3):
        """Train the MLP with Adam + cross-entropy on ``(embeddings, labels)``.

        Batches are moved to ``self.device``; the module is left in train mode.
        An empty ``embeddings`` tensor is a no-op.
        """
        if embeddings.shape[0] == 0:
            # No samples: skip instead of dividing by len(dataloader) == 0.
            return

        self.train()

        dataset = TensorDataset(embeddings, labels)
        dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

        optimizer = torch.optim.Adam(self.parameters(), lr=lr)
        criterion = nn.CrossEntropyLoss()

        for epoch in range(epochs):
            total_loss = 0
            for batch_emb, batch_labels in dataloader:
                batch_emb = batch_emb.to(self.device)
                batch_labels = batch_labels.to(self.device)

                optimizer.zero_grad()
                logits = self(batch_emb)
                loss = criterion(logits, batch_labels)
                loss.backward()
                optimizer.step()

                total_loss += loss.item()

            print(f"  MLP Epoch {epoch+1}/{epochs} - Loss: {total_loss/len(dataloader):.4f}")


class IncrementalMetaLearnerWithLoRA:
    """Streaming learner: LoRA-tuned base encoders + MLP meta-learner + drift detection.

    Samples are collected in a sliding buffer. The first full buffer triggers
    initial training. Afterwards, the meta-learner's prediction error on each
    newly buffered sample is fed once to ADWIN detectors. A detected drift
    triggers a short re-training on the buffer, which is then cleared.
    """

    def __init__(self, buffer_size: int = 100, drift_threshold: float = 0.001):
        """Build detectors, base models and meta-learner.

        Args:
            buffer_size: Capacity of the sliding ``(text, label)`` buffer.
            drift_threshold: ``delta`` passed to every ``drift.ADWIN`` detector.
        """
        self.buffer_size = buffer_size
        self.buffer = deque(maxlen=buffer_size)

        # One ADWIN detector per base-model key.
        # NOTE(review): all three are fed the same meta-learner error stream,
        # so they currently behave identically.
        self.drift_detectors = {
            'bert': drift.ADWIN(delta=drift_threshold),
            'roberta': drift.ADWIN(delta=drift_threshold),
            'electra': drift.ADWIN(delta=drift_threshold)
        }

        print("Inicializando modelos base com LoRA...")
        self.models = {
            'bert': LoRATransformerWrapper('bert-base-uncased'),
            'roberta': LoRATransformerWrapper('roberta-base'),
            'electra': LoRATransformerWrapper('google/electra-base-discriminator')
        }

        # Meta-learner input is the concatenation of all base-model embeddings,
        # so its width is the sum of their hidden sizes.
        meta_input_dim = sum(m.base_model.config.hidden_size for m in self.models.values())
        self.meta_learner = MLPMetaLearner(input_dim=meta_input_dim, num_labels=2)

        self.is_initialized = False
        self.total_samples = 0
        self.drift_count = 0
        # Number of buffered samples whose error has not yet been fed to the
        # drift detectors; guarantees each sample is seen by ADWIN exactly once.
        self._pending_drift_samples = 0

    def _extract_combined_embeddings(self, texts: List[str]) -> torch.Tensor:
        """Concatenate (along dim 1) the BERT, RoBERTa and ELECTRA embeddings of ``texts``."""
        emb_bert = self.models['bert'].encode(texts)
        emb_roberta = self.models['roberta'].encode(texts)
        emb_electra = self.models['electra'].encode(texts)

        return torch.cat([emb_bert, emb_roberta, emb_electra], dim=1)

    def _detect_drift(self) -> bool:
        """Feed meta-learner errors on not-yet-checked buffered samples to the detectors.

        Returns:
            True if any detector flagged drift during this update.
        """
        if len(self.buffer) < 30:  # minimum number of buffered samples
            return False

        # Only samples added since the last check: re-feeding overlapping
        # windows would hand ADWIN the same error many times and bias it.
        n_new = min(self._pending_drift_samples, len(self.buffer))
        if n_new == 0:
            return False
        self._pending_drift_samples = 0

        recent = list(self.buffer)[-n_new:]
        recent_texts = [text for text, _ in recent]
        recent_labels = torch.tensor([label for _, label in recent])

        embeddings = self._extract_combined_embeddings(recent_texts)
        self.meta_learner.eval()
        with torch.no_grad():
            logits = self.meta_learner(embeddings)
            predictions = torch.argmax(logits, dim=1).cpu()

        # 0/1 error per sample.
        errors = (predictions != recent_labels).float().tolist()

        drift_detected = False
        for error in errors:
            for detector in self.drift_detectors.values():
                detector.update(error)
                if detector.drift_detected:
                    drift_detected = True

        return drift_detected

    def learn_one(self, text: str, label: int):
        """Add one labelled sample; train initially or re-train on detected drift."""
        self.buffer.append((text, label))
        self.total_samples += 1
        self._pending_drift_samples += 1

        if not self.is_initialized:
            # Initialization: wait until the buffer is full for the first time.
            if len(self.buffer) == self.buffer_size:
                print(f"\n=== Inicialização com {self.buffer_size} amostras ===")
                self._initial_training()
                self.is_initialized = True
                # Training samples are not evaluated for drift.
                self._pending_drift_samples = 0
            return

        # After initialization: check drift whenever the buffer is full.
        if len(self.buffer) == self.buffer_size and self._detect_drift():
            self.drift_count += 1
            print(f"\n🔄 DRIFT DETECTADO (#{self.drift_count}) - Atualizando modelos...")
            self._incremental_update()

    def _initial_training(self):
        """Fine-tune every base model, then train the meta-learner on the first buffer."""
        buffer_list = list(self.buffer)

        print("\n1. Fine-tuning modelos base com LoRA...")
        for name, model in self.models.items():
            print(f"\n  Fine-tuning {name.upper()}...")
            model.fine_tune(buffer_list, epochs=3)

        print("\n2. Extraindo embeddings combinados...")
        texts = [text for text, _ in buffer_list]
        labels = torch.tensor([label for _, label in buffer_list], dtype=torch.long)

        embeddings = self._extract_combined_embeddings(texts)

        print("\n3. Treinando Meta-learner MLP...")
        self.meta_learner.train_mlp(embeddings, labels, epochs=5)

        print("\n✓ Inicialização completa!")

    def _incremental_update(self):
        """Re-train base models and meta-learner on the buffer after drift, then clear it."""
        buffer_list = list(self.buffer)

        # Shorter, lower-LR LoRA fine-tuning than the initial training.
        print("  Atualizando modelos com LoRA...")
        for name, model in self.models.items():
            print(f"    {name.upper()}...")
            model.fine_tune(buffer_list, epochs=2, lr=1e-4)

        print("  Atualizando Meta-learner...")
        texts = [text for text, _ in buffer_list]
        labels = torch.tensor([label for _, label in buffer_list], dtype=torch.long)
        embeddings = self._extract_combined_embeddings(texts)

        self.meta_learner.train_mlp(embeddings, labels, epochs=3, lr=5e-4)

        print("  ✓ Atualização completa!")

        # Start collecting a fresh buffer after the update.
        self.buffer.clear()

    def predict_one(self, text: str) -> int:
        """Predict the class index of ``text``; returns None before initialization."""
        if not self.is_initialized:
            return None

        self.meta_learner.eval()
        with torch.no_grad():
            embeddings = self._extract_combined_embeddings([text])
            logits = self.meta_learner(embeddings)
            prediction = torch.argmax(logits, dim=1).item()

        return prediction

    def get_stats(self) -> Dict:
        """Return sample count, drift count, current buffer length and init flag."""
        return {
            'total_samples': self.total_samples,
            'drift_count': self.drift_count,
            'buffer_size': len(self.buffer),
            'is_initialized': self.is_initialized
        }


# ========================
# USAGE EXAMPLE WITH A STREAM
# ========================

from river import datasets

dataset = datasets.SMSSpam()

system = IncrementalMetaLearnerWithLoRA(buffer_size=100, drift_threshold=0.001)

correct = 0
total = 0

print("Iniciando aprendizado incremental com SMS Spam (River)...\n")

# Prequential (test-then-train) evaluation: predict each message first,
# then learn from it.
for i, (x, y) in enumerate(dataset):
    text = x['body']
    label = int(y)  # False=0, True=1

    prediction = system.predict_one(text)

    # predict_one returns None until the system is initialized.
    if prediction is not None:
        total += 1
        if prediction == label:
            correct += 1

    system.learn_one(text, label)

    # Periodic progress log.
    if (i + 1) % 100 == 0:
        stats = system.get_stats()
        acc = (correct / total * 100) if total > 0 else 0
        print(f"\n📊 Progresso: {i+1} mensagens")
        print(f"   Acurácia online: {acc:.2f}%")
        print(f"   Drifts detectados: {stats['drift_count']}")

print("\n" + "="*50)
print("ESTATÍSTICAS FINAIS")
print("="*50)
stats = system.get_stats()
print(f"Total de amostras: {stats['total_samples']}")
print(f"Drifts detectados: {stats['drift_count']}")
# Guard: no predictions are made if the stream ends before initialization.
final_acc = (correct / total * 100) if total > 0 else 0
print(f"Acurácia final: {final_acc:.2f}%")
