'''
Author:
        PARK, JunHo, junho@ccnets.org
        KIM, JoengYoong, jeongyoong@ccnets.org

    COPYRIGHT (c) 2024. CCNets. All Rights reserved.
'''

# Is it Possible to Learn Target Prediction from Encoding? 
## Recyclable and Household Waste Classification

## Overview

This notebook demonstrates causal encoding with a cooperative
encoding network. The pipeline has two stages: an encoding
model (a cooperative encoding network built from three ResNets)
and a core model (built from three GPTs). The cells below show
how encoding and prediction learning happen inside the API:
the encoder is trained on raw images, and the core model is
then trained on the codes that the encoder produces.

## Problem Definition

In many machine learning tasks, it is hard to encode data so that
a predictive model can still learn the underlying patterns and
relationships. Traditional encoding methods, such as autoencoders,
often fail to capture the causal relationships and inherent
attributes of the data. They typically compress the data
stochastically, so the original data must first be reconstructed
with the decoder that was paired with the encoder during training
before any further prediction learning can occur.

### Key Issues with Traditional Encoding:
1. **Need for Decoding:** Further prediction learning from
   compressed data (e.g., large language model learning)
   necessitates a decoder.
2. **Paired Decoding:** The decoder must be trained together
   with the encoder to effectively reconstruct the original data.

In contrast, the causal encoding framework directly addresses these
issues by uncovering and manipulating independent causal factors
and common attributes in dataset observations. This allows for
a more structured and meaningful representation of the data,
facilitating accurate prediction learning without the need
for intermediate decoding steps.

### Advantages of Causal Encoding Framework:
- **Independent of Prediction Model:** The cooperative encoding
  network does not require training or interaction with the
  prediction model.
- **Structured Representation:** Utilizes both stochastic and
  deterministic variables to capture all causal factors in the
  latent representation, which is a concatenation of the Explainer
  and Reasoner outputs in the encoding network.

This structured approach allows the model to learn from the encoded
data and make accurate predictions more efficiently.
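
As a rough illustration of learning predictions directly from codes, the toy sketch below fits a predictor on encoded observations without ever reconstructing the inputs. Everything in it is a stand-in chosen for brevity: `encode` is a hypothetical fixed feature map (not the cooperative encoding network), the data are synthetic, and the nearest-centroid predictor replaces the GPT core model.

```python
import random


def encode(x):
    """Hypothetical stand-in encoder: average each half of a 4-value observation."""
    return (sum(x[:2]) / 2.0, sum(x[2:]) / 2.0)


def fit_centroids(codes, labels):
    """Fit a nearest-centroid predictor on codes alone (no decoder is involved)."""
    sums, counts = {}, {}
    for code, label in zip(codes, labels):
        acc = sums.setdefault(label, [0.0] * len(code))
        for i, value in enumerate(code):
            acc[i] += value
        counts[label] = counts.get(label, 0) + 1
    return {label: tuple(v / counts[label] for v in acc) for label, acc in sums.items()}


def predict(centroids, code):
    """Return the label whose centroid is closest to the code."""
    def sq_dist(label):
        return sum((a - b) ** 2 for a, b in zip(code, centroids[label]))
    return min(centroids, key=sq_dist)


rng = random.Random(0)


def sample(label):
    """Synthetic observation: class 0 is high in the first half, class 1 in the second."""
    base = [1.0, 1.0, 0.0, 0.0] if label == 0 else [0.0, 0.0, 1.0, 1.0]
    return [b + rng.uniform(-0.2, 0.2) for b in base]


# Stage 1: encode the observations. Stage 2: learn to predict from the codes only.
train_labels = [i % 2 for i in range(40)]
train_codes = [encode(sample(y)) for y in train_labels]
centroids = fit_centroids(train_codes, train_labels)

test_labels = [i % 2 for i in range(20)]
test_codes = [encode(sample(y)) for y in test_labels]
correct = sum(predict(centroids, c) == y for c, y in zip(test_codes, test_labels))
accuracy = correct / len(test_labels)
print(f"accuracy on codes: {accuracy:.2f}")
```

The predictor never sees a raw observation, which mirrors the notebook's split between `encoder_hub` and `core_hub`; the real pipeline differs in every other respect.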

### Dataset

[Recyclable and Household Waste Classification](https://www.kaggle.com/datasets/alistairking/recyclable-and-household-waste-classification)

In [1]:
import sys

path_append = "../" # Go up one directory from where you are.
sys.path.append(path_append) 

from tools.setting.ml_params import MLParameters
from tools.setting.data_config import DataConfig
from nn.utils.init import set_random_seed
set_random_seed(0)

In [2]:
dataset_path = '../data/Recyclable and Household Waste Classification/images/images'

In [3]:
import os
import torch
import torch.nn.functional as F
import random
from PIL import Image
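from torch.utils.data import Dataset


def set_random_seed(seed):
    # Local helper; shadows the set_random_seed imported from nn.utils.init in cell 1.
    # Seeds Python's `random` and PyTorch (CPU and, if available, all CUDA devices).
    random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)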

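def split_images(root_dir, split_ratio=(0.6, 0.2, 0.2)):
    """Split image paths into train/val/test lists, per class and per subfolder.

    Each list holds (image_path, class_name) tuples. Split sizes are truncated
    with int(), so any remainder images fall into the test split.
    """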
    classes = sorted(os.listdir(root_dir))
    train_paths, val_paths, test_paths = [], [], []

    for class_name in classes:
        class_dir = os.path.join(root_dir, class_name)
        for subfolder in ['default', 'real_world']:
            subfolder_dir = os.path.join(class_dir, subfolder)
            if os.path.exists(subfolder_dir):
                image_names = os.listdir(subfolder_dir)
                random.shuffle(image_names)
                num_images = len(image_names)
                train_split = int(split_ratio[0] * num_images)
                val_split = int(split_ratio[1] * num_images) + train_split

                train_paths += [(os.path.join(subfolder_dir, img), class_name) for img in image_names[:train_split]]
                val_paths += [(os.path.join(subfolder_dir, img), class_name) for img in image_names[train_split:val_split]]
                test_paths += [(os.path.join(subfolder_dir, img), class_name) for img in image_names[val_split:]]
    
    return train_paths, val_paths, test_paths

class BaseDataset(Dataset):
    def __init__(self, image_paths, transform=None):
        self.image_paths = image_paths
        self.transform = transform

    def set_seed(self, seed):
        self.seed = seed
        set_random_seed(self.seed)
        random.shuffle(self.image_paths)

    def __len__(self):
        return len(self.image_paths)

    def _load_image(self, image_path):
        image = Image.open(image_path).convert('RGB')
        if self.transform:
            image = self.transform(image)
        return image

class EncoderDataset(BaseDataset):
    def __init__(self, image_paths, transform=None):
        super().__init__(image_paths, transform)

    def __getitem__(self, index):
        # The class label is dropped: this dataset yields images only.
        image_path, _ = self.image_paths[index]
        image = self._load_image(image_path)
        return image, None

class CoreDataset(BaseDataset):
    def __init__(self, encoder, image_paths, classes, transform=None, device='cuda', precompute_batches=64):
        # Integer label for each sample: the position of its class name in `classes`.
        self.labels = [classes.index(cls) for _, cls in image_paths]

        self.encoder = encoder
        self.device = device
        self.precompute_batches = precompute_batches
        self.X_cache = None
        self.y_cache = None
        self.total_iters = 0
        self.batch_indices = []
        self.seed = 0
        super().__init__(image_paths, transform)
        self.dataset_length = len(self.image_paths)
        self._shuffle_indices()
        self._precompute_batches(0)

    def _shuffle_indices(self):
        set_random_seed(self.seed + self.total_iters)
        self.batch_indices = torch.randperm(self.dataset_length)

    def _precompute_batches(self, start_idx):
        # Encode one chunk of up to `precompute_batches` samples and cache the result.
        end_idx = min(start_idx + self.precompute_batches, self.dataset_length)
        batch_indices = self.batch_indices[start_idx:end_idx].tolist()

        images = [self._load_image(self.image_paths[i][0]) for i in batch_indices]
        images = torch.stack(images).to(self.device)

        labels = torch.tensor([self.labels[i] for i in batch_indices], dtype=torch.long).unsqueeze(-1)

        codes = self.encoder.encode(images)
        self.X_cache = codes
        self.y_cache = labels

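    def __getitem__(self, idx):
        # Locate the chunk that contains `idx` and the offset inside that chunk.
        batch_idx = idx // self.precompute_batches
        batch_start_idx = batch_idx * self.precompute_batches
        cur_idx = idx % self.precompute_batches

        # Reshuffle after every full pass, counted in __getitem__ calls.
        if self.total_iters % self.dataset_length == 0:
            self._shuffle_indices()
        # Re-encode a fresh chunk every `precompute_batches` calls.
        if self.total_iters % self.precompute_batches == 0:
            self._precompute_batches(batch_start_idx)

        self.total_iters += 1

        # The last chunk can be shorter than `precompute_batches`; wrap the offset.
        if cur_idx >= len(self.X_cache):
            cur_idx = idx % len(self.X_cache)
        return self.X_cache[cur_idx], self.y_cache[cur_idx]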
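
The index arithmetic in `CoreDataset.__getitem__` is easier to follow with concrete numbers. The sketch below copies only that arithmetic into a standalone helper; `locate` and its argument names are hypothetical and not part of the notebook's API, and `cache_len` stands for `len(self.X_cache)`.

```python
def locate(idx, chunk_size, cache_len):
    """Return (chunk_start, offset) using the same arithmetic as CoreDataset.__getitem__."""
    chunk_start = (idx // chunk_size) * chunk_size
    cur_idx = idx % chunk_size
    # A final, shorter chunk holds fewer than chunk_size items, so wrap the offset.
    if cur_idx >= cache_len:
        cur_idx = idx % cache_len
    return chunk_start, cur_idx


examples = [
    locate(0, 64, 64),    # first item of the first chunk
    locate(63, 64, 64),   # last item of the first chunk
    locate(64, 64, 64),   # first item of the second chunk
    locate(130, 64, 64),  # third chunk, offset 2
    locate(149, 64, 22),  # short final chunk of 22 items, offset still in range
    locate(100, 64, 22),  # offset 36 exceeds the 22 cached items, so it wraps
]
for chunk_start, cur_idx in examples:
    print(chunk_start, cur_idx)
```

Note that the class refreshes its cache based on the running call count (`total_iters`), not on `idx`, so the offsets line up with the cached chunk only when indices are requested in sequential order.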

In [4]:
from torchvision import transforms

# Preprocessing: resize to 128x128, convert to a tensor, normalize with ImageNet channel statistics
transform = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

num_classes = 30
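
To make the effect of the transform concrete, the sketch below reproduces the per-pixel arithmetic in plain Python: `ToTensor` scales 8-bit values to [0, 1], and `Normalize` then computes `(x - mean) / std` per channel. The helper `normalize_pixel` is a hypothetical name used only for this illustration; the mean and standard deviation are the values from the cell above.

```python
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)


def normalize_pixel(rgb_uint8):
    """Scale an 8-bit RGB pixel to [0, 1], then apply per-channel (x - mean) / std."""
    scaled = [v / 255.0 for v in rgb_uint8]
    return [(x - m) / s for x, m, s in zip(scaled, IMAGENET_MEAN, IMAGENET_STD)]


black = normalize_pixel((0, 0, 0))
white = normalize_pixel((255, 255, 255))
grey = normalize_pixel((128, 128, 128))

for name, pixel in [("black", black), ("grey", grey), ("white", white)]:
    print(name, [round(v, 4) for v in pixel])
```

After normalization the inputs are no longer confined to [0, 1]; they span roughly -2.1 to 2.6 depending on the channel.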

In [5]:

data_config = DataConfig(dataset_name = 'recycle_image', task_type='multi_class_classification', obs_shape=[3, 128, 128], label_size=num_classes)

# Set the training parameters for the `encoder model`
encoder_params = MLParameters(core_model = 'none', encoder_model = 'resnet')

# Set the training parameters for the `core model`
core_params = MLParameters(core_model = 'gpt', encoder_model = 'none')

encoder_params.training.num_epoch = 1
core_params.training.num_epoch = 1
encoder_params.encoder_config.num_layers = 5
encoder_params.encoder_config.d_model = 256

In [6]:
from trainer_hub import TrainerHub

# Set the device to GPU if available, else CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 

# encoder_hub is for the encoder model
encoder_hub = TrainerHub(encoder_params, data_config, device, use_print=True, use_wandb=False)

# core_hub is for the core model
core_hub = TrainerHub(core_params, data_config, device, use_print=True, use_wandb=False)



In [7]:
# Cooperative encoder network held by the encoder hub; CoreDataset calls its `encode` method
encoder = encoder_hub.encoder_ccnet

root_dir = path_append + dataset_path
classes = sorted(os.listdir(root_dir))
train_paths, val_paths, test_paths = split_images(root_dir)

# Create datasets
encoder_train_dataset = EncoderDataset(train_paths, transform=transform)

core_train_dataset = CoreDataset(encoder, train_paths, classes, transform=transform)
core_val_dataset = CoreDataset(encoder, val_paths, classes, transform=transform)
test_val_dataset = CoreDataset(encoder, test_paths, classes, transform=transform)
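
The split sizes produced by `split_images` follow from two `int()` truncations per subfolder. The sketch below repeats that arithmetic for two folder sizes; `split_counts` is a hypothetical helper, and the counts 250 and 7 are illustrative values rather than figures taken from the dataset.

```python
def split_counts(num_images, split_ratio=(0.6, 0.2, 0.2)):
    """Return (train, val, test) sizes using the same truncation as split_images."""
    train_split = int(split_ratio[0] * num_images)
    val_split = int(split_ratio[1] * num_images) + train_split
    return train_split, val_split - train_split, num_images - val_split


print(split_counts(250))
print(split_counts(7))
```

With 7 images the nominal shares are 4.2, 1.4 and 1.4; truncation gives 4 and 1 for train and validation, and the test split absorbs the remainder with 2 images.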

In [8]:
# Alternate encoder and core training for 10 rounds (each train() call runs num_epoch = 1)
for i in range(10):
    print("="*10,"Encoder Epoch", i,"="*10)
    # Train the encoder models with the encoder dataset
    encoder_train_dataset.set_seed(i)
    encoder_hub.train(encoder_train_dataset)
    
    print("="*10,"Core Epoch", i,"="*10)
    # Train the core models with encoded inputs
    core_train_dataset.set_seed(i)
    core_hub.train(core_train_dataset, core_val_dataset)
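
When reading the test metrics that this loop prints, a pattern such as `accuracy: 0.0195`, `precision: 0.0007`, `recall: 0.0333`, `f1_score: 0.0013` is what a classifier produces when it assigns every sample to a single class. The sketch below reproduces those four numbers under stated assumptions: macro averaging over 30 classes with undefined ratios counted as 0, and a hypothetical evaluation batch of 256 samples of which 5 belong to the predicted class. How the API actually computes its metrics is not shown in this notebook, so treat this as a plausibility check only.

```python
def macro_metrics(y_true, y_pred, num_classes):
    """Accuracy plus macro-averaged precision, recall and F1 (undefined ratios count as 0)."""
    precisions, recalls, f1s = [], [], []
    for c in range(num_classes):
        tp = sum(1 for t, p in zip(y_true, y_pred) if t == c and p == c)
        fp = sum(1 for t, p in zip(y_true, y_pred) if t != c and p == c)
        fn = sum(1 for t, p in zip(y_true, y_pred) if t == c and p != c)
        precision = tp / (tp + fp) if tp + fp else 0.0
        recall = tp / (tp + fn) if tp + fn else 0.0
        f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
        precisions.append(precision)
        recalls.append(recall)
        f1s.append(f1)
    accuracy = sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)
    return (accuracy,
            sum(precisions) / num_classes,
            sum(recalls) / num_classes,
            sum(f1s) / num_classes)


num_classes = 30
# Hypothetical evaluation batch: 256 samples, 5 from class 0,
# the remaining 251 spread over classes 1..29.
y_true = [0] * 5 + [1 + (i % 29) for i in range(251)]
# A collapsed classifier that predicts class 0 for every sample.
y_pred = [0] * len(y_true)

accuracy, precision, recall, f1 = macro_metrics(y_true, y_pred, num_classes)
print(f"accuracy: {accuracy:.4f}")
print(f"precision: {precision:.4f}")
print(f"recall: {recall:.4f}")
print(f"f1_score: {f1:.4f}")
```

Macro recall of exactly 1/30 combined with a much lower macro precision is the signature to look for: one class has recall 1 and all others have recall 0. For reference, chance-level accuracy on 30 balanced classes is also 1/30, about 0.0333.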




[0/1][50/140][Time 16.94]
Unified LR across all optimizers: 0.0001995308238189185
--------------------Training Metrics--------------------
Cooperative Network(encoder):  Three Resnet
Inf: 1.2369	Gen: 1.1362	Rec: 1.3511	E: 1.0219	R: 1.4519	P: 1.2504
[0/1][100/140][Time 16.42]
Unified LR across all optimizers: 0.00019907191565870155
--------------------Training Metrics--------------------
Cooperative Network(encoder):  Three Resnet
Inf: 1.1123	Gen: 1.0689	Rec: 1.2009	E: 0.9803	R: 1.2443	P: 1.1574



[0/1][50/140][Time 11.49]
Unified LR across all optimizers: 0.0001995308238189185
--------------------Training Metrics--------------------
Cooperative Network(core):  Three Gpt
Inf: 0.2921	Gen: 0.5550	Rec: 0.4717	E: 0.3754	R: 0.2088	P: 0.7345
--------------------Test Metrics------------------------
accuracy: 0.0312
precision: 0.0027
recall: 0.0222
f1_score: 0.0048

[0/1][100/140][Time 11.44]
Unified LR across all optimizers: 0.00019907191565870155
--------------------Training Metrics--------------------
Cooperative Network(core):  Three Gpt
Inf: 0.0816	Gen: 0.2577	Rec: 0.2407	E: 0.0986	R: 0.0645	P: 0.4168
--------------------Test Metrics------------------------
accuracy: 0.0664
precision: 0.0128
recall: 0.0492
f1_score: 0.0160




[0/1][50/140][Time 63.03]
Unified LR across all optimizers: 0.00019824853909673957
--------------------Training Metrics--------------------
Cooperative Network(encoder):  Three Resnet
Inf: 1.9668	Gen: 1.8835	Rec: 2.0800	E: 1.7703	R: 2.1634	P: 1.9966
[0/1][100/140][Time 16.50]
Unified LR across all optimizers: 0.000197792580109545
--------------------Training Metrics--------------------
Cooperative Network(encoder):  Three Resnet
Inf: 1.0759	Gen: 1.0243	Rec: 1.1316	E: 0.9686	R: 1.1832	P: 1.0800



[0/1][50/140][Time 68.99]
Unified LR across all optimizers: 0.00019824853909673957
--------------------Training Metrics--------------------
Cooperative Network(core):  Three Gpt
Inf: 0.1141	Gen: 0.6790	Rec: 0.6658	E: 0.1274	R: 0.1008	P: 1.2307
--------------------Test Metrics------------------------
accuracy: 0.0195
precision: 0.0007
recall: 0.0333
f1_score: 0.0013

[0/1][100/140][Time 11.36]
Unified LR across all optimizers: 0.000197792580109545
--------------------Training Metrics--------------------
Cooperative Network(core):  Three Gpt
Inf: 0.0551	Gen: 0.3333	Rec: 0.3281	E: 0.0603	R: 0.0499	P: 0.6063
--------------------Test Metrics------------------------
accuracy: 0.0195
precision: 0.0007
recall: 0.0333
f1_score: 0.0013




[0/1][50/140][Time 63.95]
Unified LR across all optimizers: 0.00019697449497657537
--------------------Training Metrics--------------------
Cooperative Network(encoder):  Three Resnet
Inf: 1.9213	Gen: 1.8351	Rec: 2.0073	E: 1.7491	R: 2.0935	P: 1.9211
[0/1][100/140][Time 16.49]
Unified LR across all optimizers: 0.00019652146620954448
--------------------Training Metrics--------------------
Cooperative Network(encoder):  Three Resnet
Inf: 1.0847	Gen: 1.0232	Rec: 1.1261	E: 0.9818	R: 1.1876	P: 1.0646
