'''
Author:
        
        PARK, JunHo, junho@ccnets.org

        
        KIM, JoengYoong, jeongyoong@ccnets.org
        
    COPYRIGHT (c) 2024. CCNets. All Rights reserved.
'''

# Is it Possible to Learn Target Prediction from Encoding? 
## Recyclable and Household Waste Classification

## Overview

This notebook demonstrates the use of a cooperative encoding
network for causal encoding. The pipeline consists of an encoding
model (a cooperative encoding network built from three ResNets)
and a core model (three GPTs). It shows how the encoding process
and prediction learning are carried out internally by the API,
as illustrated in the example file.

## Problem Definition

In many machine learning tasks, it is challenging to encode data
so that a predictive model can learn the underlying patterns and
relationships. Traditional encoding methods, such as autoencoders,
often fail to capture the causal relationships and inherent
attributes of the data. These methods typically compress the data
stochastically, so the original data must first be recovered with
the decoder that was trained alongside the encoder before further
prediction learning can occur.

### Key Issues with Traditional Encoding:
1. **Need for Decoding:** Further prediction learning from
   compressed data (e.g., large language model learning)
   necessitates a decoder.
2. **Paired Decoding:** The decoder must be trained together
   with the encoder to effectively reconstruct the original data.

In contrast, the causal encoding framework directly addresses these
issues by uncovering and manipulating independent causal factors
and common attributes in dataset observations. This allows for
a more structured and meaningful representation of the data,
facilitating accurate prediction learning without the need
for intermediate decoding steps.

### Advantages of Causal Encoding Framework:
- **Independent of Prediction Model:** The cooperative encoding
  network does not require training or interaction with the
  prediction model.
- **Structured Representation:** Utilizes both stochastic and
  deterministic variables to capture all causal factors in the
  latent representation, which is a concatenation of the Explainer
  and Reasoner outputs in the encoding network.

This structured approach allows the model to learn from the encoded
data and make accurate predictions more efficiently.

### Dataset

[Recyclable and Household Waste Classification](https://www.kaggle.com/datasets/alistairking/recyclable-and-household-waste-classification)

In [None]:
# Notebook setup: make the project packages importable, then fix the random seed.
import sys

# Relative to the current working directory. It is also reused later as the
# prefix of the dataset path (``path_append + dataset_path``).
path_append = "../" # Go up one directory from where you are.
# Must run before the project imports below, which are resolved through this entry.
sys.path.append(path_append) 

from tools.setting.ml_params import MLParameters
from tools.setting.data_config import DataConfig
from nn.utils.init import set_random_seed
# Seed 0 for reproducibility.
# NOTE(review): which RNGs this seeds (Python / NumPy / torch) is defined in
# nn.utils.init — confirm there.
set_random_seed(0)

In [None]:
# Image root of the Kaggle dataset, relative to the notebook. Later cells open
# it as ``path_append + dataset_path``.
# NOTE(review): path_append is "../", so the effective root becomes
# "../../data/..." — confirm this matches the repository layout.
dataset_path = (
    '../data/Recyclable and Household Waste Classification/'
    'images/images'
)

In [None]:
import os
import torch
import torch.nn.functional as F
import random
from PIL import Image
from torch.utils.data import Dataset


class WasteDataset_for_encoder(Dataset):
    """Image-only dataset used to train the encoder model.

    Scans ``root_dir/<class>/default`` and ``root_dir/<class>/real_world``
    and records every file path found there, together with the index of its
    class. Classes are the sorted entries of ``root_dir``. Items are returned
    as ``(image, None)``, so no label reaches the encoder.

    Note:
        ``split`` is accepted for signature parity with
        ``WasteDataset_for_core`` but is ignored. Every image under
        ``root_dir`` is included whatever its value.
    """

    # Per-class folders that are scanned, in this order.
    _SUBFOLDERS = ('default', 'real_world')

    def __init__(self, root_dir, split, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.classes = sorted(os.listdir(root_dir))
        self.image_paths = []
        self.labels = []

        for label, class_name in enumerate(self.classes):
            for subfolder in self._SUBFOLDERS:
                folder = os.path.join(root_dir, class_name, subfolder)
                if not os.path.exists(folder):
                    continue
                names = os.listdir(folder)
                # In-place shuffle driven by the global ``random`` state.
                random.shuffle(names)
                self.image_paths.extend(os.path.join(folder, name) for name in names)
                self.labels.extend([label] * len(names))

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, index):
        """Return ``(image, None)`` for the file at position ``index``.

        The file is loaded as an RGB PIL image and passed through
        ``transform`` when one was given.
        """
        rgb = Image.open(self.image_paths[index]).convert('RGB')
        return (self.transform(rgb) if self.transform else rgb), None
    
    
class WasteDataset_for_core(Dataset):
    """Dataset of encoder outputs (codes) and labels for training the core model.

    Images are read from ``root_dir/<class>/default`` and
    ``root_dir/<class>/real_world``. Each image folder is shuffled and split
    60/20/20 into train/val/test. Images are encoded in chunks of
    ``precompute_batches`` by ``trainer_hub1.helper.encode_inputs`` and the
    results are cached, so the encoder runs once per chunk rather than once
    per item.

    Args:
        trainer_hub1: Object whose ``helper`` provides ``encode_inputs(images, labels)``.
        root_dir: Dataset root with one sub-directory per class.
        split: ``'train'``, ``'val'`` or ``'test'``.
        transform: Optional callable applied to each RGB PIL image. Its results
            are combined with ``torch.stack``, so it must return tensors of
            identical shape.
        device: Device the image batch and one-hot labels are moved to before
            encoding.
        precompute_batches: Number of images encoded per cache refresh.
        seed: Seed of the per-folder shuffle. Instances built with the same
            ``seed`` over the same directory contents use identical
            permutations, so their train/val/test splits are disjoint.

    Raises:
        ValueError: If ``split`` is not ``'train'``, ``'val'`` or ``'test'``, or
            if no images are found for the requested split.
    """

    # Per-class folders that are scanned, in this order.
    _SUBFOLDERS = ('default', 'real_world')
    # Accepted values of ``split``.
    _SPLITS = ('train', 'val', 'test')

    def __init__(self, trainer_hub1, root_dir, split, transform=None, device='cuda', precompute_batches=64, seed=0):
        # Validate up front: previously any unknown value silently yielded the test slice.
        if split not in self._SPLITS:
            raise ValueError(f"split must be one of {self._SPLITS}, got {split!r}")
        self.encoder = trainer_hub1.helper
        self.root_dir = root_dir
        self.transform = transform
        self.classes = sorted(os.listdir(root_dir))
        self.device = device
        self.image_paths = []
        self.labels = []
        self.precompute_batches = precompute_batches
        self.dataset_length = 0
        self.X_cache = None
        self.y_cache = None
        self.total_iters = 0
        self.batch_indices = []

        # A dedicated, seeded RNG replaces the global ``random`` module. With
        # it, the train/val/test instances slice the *same* permutation of each
        # folder. With the global state, each instance sliced a different
        # shuffle, and train images leaked into val/test.
        rng = random.Random(seed)
        for i, class_name in enumerate(self.classes):
            class_dir = os.path.join(root_dir, class_name)
            for subfolder in self._SUBFOLDERS:
                subfolder_dir = os.path.join(class_dir, subfolder)
                if not os.path.exists(subfolder_dir):
                    continue
                # os.listdir order is arbitrary; sort so the seeded shuffle is reproducible.
                image_names = sorted(os.listdir(subfolder_dir))
                rng.shuffle(image_names)
                for image_name in self._split_images(image_names, split):
                    self.image_paths.append(os.path.join(subfolder_dir, image_name))
                    self.labels.append(i)

        self.dataset_length = len(self.image_paths)
        if self.dataset_length == 0:
            raise ValueError(f"No images found under {root_dir!r} for split {split!r}")
        self._shuffle_indices()
        self._precompute_batches(0)

    def _split_images(self, image_names, split):
        """Return the ``split`` slice of an already shuffled ``image_names`` list.

        'train' is the first 60%, 'val' the next 20% and 'test' the remainder.

        Raises:
            ValueError: If ``split`` is not 'train', 'val' or 'test'.
        """
        n = len(image_names)
        if split == 'train':
            return image_names[:int(0.6 * n)]
        if split == 'val':
            return image_names[int(0.6 * n):int(0.8 * n)]
        if split == 'test':
            return image_names[int(0.8 * n):]
        raise ValueError(f"split must be one of {self._SPLITS}, got {split!r}")

    def _shuffle_indices(self):
        """Draw a fresh random permutation of all item positions."""
        self.batch_indices = torch.randperm(self.dataset_length)

    @staticmethod
    def _load_image(path):
        """Load ``path`` as an RGB PIL image and close the file handle."""
        with Image.open(path) as img:
            # convert() returns a new image, so it stays valid after the file closes.
            return img.convert('RGB')

    def _precompute_batches(self, start_idx):
        """Encode the chunk of the permutation that starts at ``start_idx``.

        Loads the images at ``self.batch_indices[start_idx:start_idx + precompute_batches]``
        and one-hot encodes their labels with ``len(self.classes)`` columns.
        Both are passed through ``self.encoder.encode_inputs`` under
        ``torch.no_grad()``, and the returned codes and labels are stored in
        ``X_cache`` and ``y_cache``.
        """
        end_idx = min(start_idx + self.precompute_batches, self.dataset_length)
        batch_indices = self.batch_indices[start_idx:end_idx].tolist()

        images = [self._load_image(self.image_paths[i]) for i in batch_indices]
        if self.transform:
            images = [self.transform(img) for img in images]
        images = torch.stack(images).to(self.device)

        labels = torch.tensor([self.labels[i] for i in batch_indices], dtype=torch.long)
        labels = F.one_hot(labels, num_classes=len(self.classes)).to(self.device)

        with torch.no_grad():
            # NOTE(review): assumes encode_inputs accepts batched (images, one-hot
            # labels) and returns (codes, labels) row-aligned — confirm against the helper.
            codes, labels = self.encoder.encode_inputs(images, labels)

        self.X_cache = codes
        self.y_cache = labels

    def __len__(self):
        return self.dataset_length

    def __getitem__(self, idx):
        """Return one ``(code, label)`` pair from the cached, encoded chunk.

        The permutation is reshuffled once every ``len(self)`` calls, including
        the first call. A new chunk is encoded once every
        ``precompute_batches`` calls.

        NOTE(review): refreshes are driven by the call counter
        ``total_iters``, not by whether ``idx`` lies in the cached chunk.
        Under a shuffled sampler, for example, the returned item is therefore
        not guaranteed to be image ``batch_indices[idx]``. The code and label
        are always taken from the same cache row.
        """
        batch_idx = idx // self.precompute_batches
        batch_start_idx = batch_idx * self.precompute_batches
        cur_idx = idx % self.precompute_batches

        if self.total_iters % self.dataset_length == 0:
            self._shuffle_indices()
        if self.total_iters % self.precompute_batches == 0:
            self._precompute_batches(batch_start_idx)

        self.total_iters += 1

        # The cached chunk can be shorter than precompute_batches (e.g. the
        # final chunk of the permutation); wrap instead of indexing past its end.
        if cur_idx >= len(self.X_cache):
            cur_idx = idx % len(self.X_cache)

        return self.X_cache[cur_idx], self.y_cache[cur_idx]

In [None]:
from torchvision import transforms

# Preprocessing shared by every dataset in this notebook: resize to 224x224,
# convert to a float tensor, then normalise each channel with the standard
# ImageNet mean/std.
_resize = transforms.Resize((224, 224))
_to_tensor = transforms.ToTensor()
_normalize = transforms.Normalize(
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225],
)
transform = transforms.Compose([_resize, _to_tensor, _normalize])

# Training set for the encoder; its items are (image, None).
# NOTE(review): path_append ("../") + dataset_path ("../data/...") yields
# "../../data/..." — confirm this is the intended location.
encoder_root = path_append + dataset_path
train_dataset_for_encoder = WasteDataset_for_encoder(encoder_root, split='train', transform=transform)

In [None]:
# Data description shared by both TrainerHubs: 3x224x224 observations (the
# size produced by the Resize((224, 224)) transform) and one output per class.
# label_size is derived from the dataset's class listing instead of being
# hard-coded to 30. It must equal the one-hot width WasteDataset_for_core
# produces (``F.one_hot(..., num_classes=len(self.classes))`` over the same
# sorted listing of the dataset root). A stray entry in that directory would
# otherwise silently break the match.
# NOTE(review): each image carries exactly one class (one-hot labels); confirm
# that 'multi_label_classification' rather than a multi-class task type is intended.
data_config = DataConfig(
    dataset_name='recycle_image',
    task_type='multi_label_classification',
    obs_shape=[3, 224, 224],
    label_size=len(train_dataset_for_encoder.classes),
)

# Set training configuration from the MLParameters class, returning them as a Namespace object.

# Training parameters for the encoder model (ResNet encoder, no core model).
ml_params1 = MLParameters(core_model='none', encoder_model='resnet')

# Training parameters for the core model (GPT core, no encoder model).
ml_params2 = MLParameters(core_model='gpt', encoder_model='none')

# Presumably the number of epochs run by each TrainerHub.train call.
ml_params1.training.num_epoch = 1
ml_params2.training.num_epoch = 1

In [None]:
from trainer_hub import TrainerHub

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Options shared by both hubs: printing enabled (print_interval=20), wandb disabled.
hub_options = dict(use_print=True, use_wandb=False, print_interval=20)

# Two independent TrainerHubs over the same data configuration:
# trainer_hub1 trains the encoder model (ml_params1),
# trainer_hub2 trains the core model (ml_params2).
trainer_hub1 = TrainerHub(ml_params1, data_config, device, **hub_options)
trainer_hub2 = TrainerHub(ml_params2, data_config, device, **hub_options)

In [None]:
# The core datasets call ``encode_inputs`` on ``trainer_hub1.helper``, so they
# are given the encoder's TrainerHub.
encoder = trainer_hub1

# Build the train/val/test sets for the core model. ``device`` is passed
# explicitly: WasteDataset_for_core defaults to 'cuda', which crashes on
# CPU-only machines even though ``device`` falls back to the CPU.
train_dataset_for_core = WasteDataset_for_core(encoder, path_append + dataset_path, split='train', transform=transform, device=device)
val_dataset_for_core = WasteDataset_for_core(encoder, path_append + dataset_path, split='val', transform=transform, device=device)
test_dataset_for_core = WasteDataset_for_core(encoder, path_append + dataset_path, split='test', transform=transform, device=device)

In [10]:
# Alternate encoder and core training for 10 rounds. The core dataset encodes
# images lazily through trainer_hub1.helper, so each core round consumes codes
# from the encoder as it stands after that round's encoder training.
banner = "=" * 10
for epoch in range(10):
    print(banner, "Encoder Epoch", epoch, banner)
    trainer_hub1.train(train_dataset_for_encoder)

    print(banner, "Core Epoch", epoch, banner)
    trainer_hub2.train(train_dataset_for_core, val_dataset_for_core)