# In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import pandas as pd
import numpy as np
import copy
import matplotlib.pyplot as plt


################### SIX SHOT #########


# Expected class names; a label's position in this list is used as its
# integer class index.
classes = [
    'Pastos',
    'planta_daninha',
    'planta_toxicas',
]

# Custom Dataset that serves (embedding, label) pairs read from a CSV file
class EmbeddingDataset(Dataset):
    """Dataset backed by a CSV of class labels and stringified embeddings.

    The CSV is read with its first row as the header. Each data row must hold
    exactly two values, in this order: the class name and the embedding
    written as space-separated numbers, optionally wrapped in square brackets.

    Args:
        csv_file: Path of the CSV file to load.
        num_classes: How many leading entries of the module-level ``classes``
            list are mapped to integer indices.
    """

    def __init__(self, csv_file, num_classes=3):
        # header=0: the first CSV row holds the column names
        self.data = pd.read_csv(csv_file, header=0)
        self.num_classes = num_classes

        # class name -> integer index, following the order of ``classes``
        mapping = {}
        for position in range(num_classes):
            mapping[classes[position]] = position
        self.class_to_idx = mapping

    def __len__(self):
        """Return the number of data rows in the CSV."""
        return len(self.data.index)

    def __getitem__(self, idx):
        """Return ``(embedding, label)`` for row ``idx``.

        ``embedding`` is a 1-D ``torch.float32`` tensor and ``label`` is the
        integer index of the row's class name.
        """
        row = self.data.iloc[idx]
        label_name, raw_embedding = row  # fails unless the row has two values
        label = self.class_to_idx[label_name]

        # Drop the surrounding brackets, then parse the space-separated floats
        stripped = raw_embedding.strip('[]')
        values = np.fromstring(stripped, sep=' ')
        return torch.tensor(values, dtype=torch.float32), label

# Build the dataset and the train/test DataLoaders.
# The dataset is split 80% / 20% at random; no seed is set here, so the split
# differs between runs unless the global torch RNG is seeded beforehand.
csv_file = './embeddings_with_augmentation.csv'  
dataset = EmbeddingDataset(csv_file)
train_size = int(0.8 * len(dataset))
test_size = len(dataset) - train_size  # remainder, so both sizes sum to len(dataset)
train_dataset, test_dataset = torch.utils.data.random_split(dataset, [train_size, test_size])
# Each training batch of (up to) 12 samples is consumed as one task by outer_loop
train_loader = DataLoader(train_dataset, batch_size=12, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=2, shuffle=False)

# Network redefined to work on embedding vectors (a small fully connected
# classifier, no convolutional layers)
class EmbeddingClassifier(nn.Module):
    """Two-layer perceptron: ``Linear(input_dim, 128) -> ReLU -> Linear(128, num_classes)``.

    Args:
        input_dim: Size of the last dimension of the input embeddings.
        num_classes: Number of output logits.
    """

    def __init__(self, input_dim, num_classes):
        super().__init__()
        hidden_units = 128
        self.fc1 = nn.Linear(input_dim, hidden_units)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_units, num_classes)

    def forward(self, x):
        """Return the raw (unnormalised) class logits for ``x``."""
        hidden = self.relu(self.fc1(x))
        return self.fc2(hidden)

# Infer the embedding size (input_dim) from the first sample of the dataset:
# dataset[0] is an (embedding, label) pair and the embedding is a 1-D tensor
input_dim = dataset[0][0].shape[0]

# Instantiate the model with one output logit per entry of `classes`
model = EmbeddingClassifier(input_dim=input_dim, num_classes=len(classes))

# Define the loss function and the optimizer for the meta-model's parameters
loss_fn = nn.CrossEntropyLoss()
meta_optimizer = optim.Adam(model.parameters(), lr=0.001)

# Helper that computes classification accuracy
def calculate_accuracy(outputs, labels):
    """Return the fraction of rows of ``outputs`` whose arg-max equals ``labels``.

    Args:
        outputs: 2-D tensor of per-class scores, one row per sample.
        labels: 1-D tensor of integer class indices, one per row of ``outputs``.

    Returns:
        A Python float in ``[0, 1]``.
    """
    # [1] selects the arg-max indices from the (values, indices) pair
    predicted = outputs.data.max(dim=1)[1]
    hits = predicted.eq(labels).sum().item()
    return hits / labels.size(0)

# Inner loop for few-shot learning (task-specific update)
def inner_loop(support_inputs, support_labels, model, loss_fn, lr=0.01):
    """Return a copy of ``model`` adapted with one SGD step on the support set.

    ``model`` itself is left untouched: the step is applied to a deep copy.

    Args:
        support_inputs: Batch of inputs for the task's support set.
        support_labels: Targets matching ``support_inputs``.
        model: Module to copy and adapt.
        loss_fn: Callable ``loss_fn(outputs, labels)`` returning a scalar loss.
        lr: Learning rate of the single SGD step.

    Returns:
        The adapted deep copy of ``model``.
    """
    adapted = copy.deepcopy(model)
    sgd = optim.SGD(adapted.parameters(), lr=lr)

    # Clear any gradients carried over by the deep copy before the backward pass
    sgd.zero_grad()
    support_loss = loss_fn(adapted(support_inputs), support_labels)
    support_loss.backward()
    sgd.step()

    return adapted

# Outer loop (meta-update)
def outer_loop(train_loader, model, meta_optimizer, loss_fn, train_losses, train_accuracies, num_tasks=1):
    """Run one meta-training epoch with a first-order MAML update.

    Every batch yielded by ``train_loader`` is treated as one task: all but
    the last two samples form the support set and the last two form the query
    set. ``inner_loop`` adapts a copy of ``model`` on the support set, the
    adapted copy is scored on the query set, and the query-loss gradients are
    accumulated onto ``model``'s parameters. A single ``meta_optimizer`` step
    is taken after all tasks have been processed.

    The adapted model is a deep copy, so calling ``backward()`` on the query
    loss would only populate the copy's gradients and leave ``model``
    untrained. The gradients are therefore transferred explicitly (first-order
    approximation: the gradient at the adapted weights is used as the
    meta-gradient).

    Args:
        train_loader: Iterable of ``(inputs, labels)`` batches; one batch is
            one task.
        model: Meta-model whose parameters ``meta_optimizer`` updates.
        meta_optimizer: Optimizer built over ``model.parameters()``.
        loss_fn: Callable ``loss_fn(outputs, labels)`` returning a scalar loss.
        train_losses: List that receives this epoch's mean query loss.
        train_accuracies: List that receives this epoch's mean query accuracy.
        num_tasks: Unused; kept for backward compatibility.

    Returns:
        ``(avg_train_loss, avg_train_accuracy)`` averaged over the tasks that
        were actually processed.

    Raises:
        ValueError: If no batch has more than two samples, i.e. no task with a
            non-empty support set could be formed.
    """
    # NOTE(review): the support size is ``batch_size - 2`` (10 with the
    # batch size of 12 used by this script), not six -- confirm the intended
    # number of shots.
    query_size = 2

    model.train()  # evaluate_model() leaves the model in eval mode
    meta_optimizer.zero_grad()
    meta_params = list(model.parameters())

    epoch_loss = 0.0
    epoch_accuracy = 0.0
    tasks_seen = 0

    for inputs, labels in train_loader:
        split = len(inputs) - query_size
        if split <= 0:
            # A short trailing batch would give an empty support set, whose
            # mean-reduced loss is NaN and would poison the meta-gradient.
            continue

        # Split the task's data into support and query sets
        support_inputs, query_inputs = inputs[:split], inputs[split:]
        support_labels, query_labels = labels[:split], labels[split:]

        # Inner loop: task-specific adaptation on a copy of the model
        temp_model = inner_loop(support_inputs, support_labels, model, loss_fn)

        # Evaluate the adapted copy on the query set
        query_outputs = temp_model(query_inputs)
        query_loss = loss_fn(query_outputs, query_labels)

        # Collect loss and accuracy
        epoch_loss += query_loss.item()
        epoch_accuracy += calculate_accuracy(query_outputs, query_labels)

        # Meta-gradient: take the query-loss gradients w.r.t. the adapted
        # weights and accumulate them onto the matching meta-parameters.
        # autograd.grad is used so stale .grad values left on temp_model by
        # the inner step cannot leak into the meta-gradient.
        task_grads = torch.autograd.grad(
            query_loss, list(temp_model.parameters()), allow_unused=True
        )
        for meta_param, task_grad in zip(meta_params, task_grads):
            if task_grad is None:
                continue
            if meta_param.grad is None:
                meta_param.grad = task_grad.detach().clone()
            else:
                meta_param.grad.add_(task_grad.detach())

        tasks_seen += 1

    if tasks_seen == 0:
        raise ValueError(
            "train_loader yielded no batch with more than 2 samples; "
            "cannot form a support/query split"
        )

    meta_optimizer.step()

    avg_train_loss = epoch_loss / tasks_seen
    avg_train_accuracy = epoch_accuracy / tasks_seen

    train_losses.append(avg_train_loss)
    train_accuracies.append(avg_train_accuracy)

    return avg_train_loss, avg_train_accuracy

# Model evaluation
def evaluate_model(test_loader, test_losses, test_accuracies):
    """Score the module-level ``model`` on ``test_loader`` and record the results.

    The model is switched to eval mode and is not switched back afterwards.
    Uses the module-level ``model`` and ``loss_fn``.

    Args:
        test_loader: Iterable of ``(inputs, labels)`` batches.
        test_losses: List that receives the mean per-batch loss.
        test_accuracies: List that receives the overall sample accuracy.

    Returns:
        ``(avg_test_loss, avg_test_accuracy)``: the loss averaged over
        batches and the accuracy computed over all samples.
    """
    model.eval()
    running_loss = 0
    n_correct = 0
    n_seen = 0

    with torch.no_grad():
        for batch_inputs, batch_labels in test_loader:
            logits = model(batch_inputs)
            running_loss += loss_fn(logits, batch_labels).item()

            # [1] selects the arg-max indices from the (values, indices) pair
            batch_pred = logits.data.max(dim=1)[1]
            n_seen += batch_labels.size(0)
            n_correct += (batch_pred == batch_labels).sum().item()

    # Loss is a mean over batches; accuracy is a mean over individual samples
    avg_test_loss = running_loss / len(test_loader)
    avg_test_accuracy = n_correct / n_seen

    test_losses.append(avg_test_loss)
    test_accuracies.append(avg_test_accuracy)
    return avg_test_loss, avg_test_accuracy

# Helper that plots a single metric in its own figure
def plot_individual_metrics(epochs, values, label, title, ylabel):
    """Show one 8x6 line plot of ``values`` against ``epochs``.

    Args:
        epochs: X-axis values.
        values: Y-axis values, one per entry of ``epochs``.
        label: Legend entry for the curve.
        title: Figure title.
        ylabel: Y-axis label.
    """
    _, axes = plt.subplots(figsize=(8, 6))
    axes.plot(epochs, values, 'b-', label=label)  # solid blue line
    axes.set_title(title)
    axes.set_xlabel('Épocas')
    axes.set_ylabel(ylabel)
    axes.legend()
    axes.grid(True)
    plt.show()

# Helper that plots every metric in a separate figure
def plot_metrics_separately(train_losses, train_accuracies, test_losses, test_accuracies):
    """Plot train/test loss and train/test accuracy, one figure per metric.

    The x axis runs from 1 to ``len(train_losses)``. Figures are shown in the
    order: training loss, test loss, training accuracy, test accuracy.
    """
    epochs = range(1, len(train_losses) + 1)

    # (series, legend label and title, y-axis label), in display order
    panels = (
        (train_losses, 'Perda de Treinamento', 'Perda'),
        (test_losses, 'Perda de Teste', 'Perda'),
        (train_accuracies, 'Acurácia de Treinamento', 'Acurácia'),
        (test_accuracies, 'Acurácia de Teste', 'Acurácia'),
    )
    for series, caption, axis_label in panels:
        plot_individual_metrics(epochs, series, caption, caption, axis_label)

# Training with metric collection.
# The four lists below are filled in place, one entry per epoch, by
# outer_loop (train_*) and evaluate_model (test_*).
num_epochs = 25
train_losses = []
train_accuracies = []
test_losses = []
test_accuracies = []

for epoch in range(num_epochs):
    # One meta-training pass over train_loader, then one evaluation pass over test_loader
    avg_train_loss, avg_train_accuracy = outer_loop(train_loader, model, meta_optimizer, loss_fn, train_losses, train_accuracies)
    avg_test_loss, avg_test_accuracy = evaluate_model(test_loader, test_losses, test_accuracies)
    
    print(f"Epoch {epoch+1}/{num_epochs}, Train Loss: {avg_train_loss:.4f}, Train Acc: {avg_train_accuracy:.4f}, Test Loss: {avg_test_loss:.4f}, Test Acc: {avg_test_accuracy:.4f}")

# Plot each metric in its own figure
plot_metrics_separately(train_losses, train_accuracies, test_losses, test_accuracies)

# Final model evaluation: report the last epoch's test accuracy as a percentage
accuracy = 100 * test_accuracies[-1]
print(f"Accuracy on the new task or domain: {accuracy:.2f}%")
