'''
Author:
        
        PARK, JunHo, junho@ccnets.org

        
        KIM, JoengYoong, jeongyoong@ccnets.org
        
    COPYRIGHT (c) 2024. CCNets. All Rights reserved.
'''

# Is it Possible to Learn Target Prediction from Encoding? 
## Recyclable and Household Waste Classification

## Overview

This notebook demonstrates causal encoding with a cooperative encoding
network. The pipeline consists of an encoding model (a cooperative
encoding network built from three ResNets) and a core model (three
GPTs). It shows how the API carries out the encoding process and
prediction learning internally, as illustrated in the example file.

## Problem Definition

In many machine learning tasks, it is challenging to encode data so
that a predictive model can learn its underlying patterns and
relationships. Traditional encoding methods, such as autoencoders,
often fail to capture the causal relationships and inherent
attributes of the data. These methods typically compress the data
stochastically, so the original data must first be recovered with the
paired decoder used in training before any further prediction learning
can take place.

### Key Issues with Traditional Encoding:
1. **Need for Decoding:** Further prediction learning from
   compressed data (e.g., large language model learning)
   necessitates a decoder.
2. **Paired Decoding:** The decoder must be trained together
   with the encoder to effectively reconstruct the original data.
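To make the two issues concrete, here is a toy sketch of the traditional pipeline. A predictor fitted on raw observations cannot consume the compressed code directly; the code has to be decoded first, and only by the decoder that matches the encoder. The averaging `encode`, the `paired_decode` and the `predictor` below are hypothetical stand-ins chosen so the arithmetic is easy to follow. They are not part of the API.

```python
# Toy illustration (hypothetical, not the API): an encoder and its paired decoder.

def encode(x):
    """Compress a length-2n vector to length n by averaging adjacent pairs."""
    return [(x[i] + x[i + 1]) / 2 for i in range(0, len(x), 2)]


def paired_decode(z):
    """The decoder that matches `encode`: expand each code value back into two slots."""
    out = []
    for v in z:
        out.extend([v, v])
    return out


def predictor(x):
    """A predictor fitted on raw observations: it expects exactly 4 features."""
    if len(x) != 4:
        raise ValueError(f"predictor expects 4 features, got {len(x)}")
    return 1 if sum(x) > 0 else 0


x = [0.5, 1.5, -0.2, 0.2]
z = encode(x)  # the compressed code

# Issue 1: the predictor cannot use the code directly ...
try:
    predictor(z)
except ValueError as err:
    print("direct use fails:", err)

# Issue 2: ... only the decoder paired with this encoder maps it back.
x_hat = paired_decode(z)
print("prediction after decoding:", predictor(x_hat))
```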

In contrast, the causal encoding framework directly addresses these
issues by uncovering and manipulating independent causal factors
and common attributes in dataset observations. This allows for
a more structured and meaningful representation of the data,
facilitating accurate prediction learning without the need
for intermediate decoding steps.

### Advantages of Causal Encoding Framework:
- **Independent of Prediction Model:** The cooperative encoding
  network does not require training or interaction with the
  prediction model.
- **Structured Representation:** Utilizes both stochastic and
  deterministic variables to capture all causal factors in the
  latent representation, which is a concatenation of the Explainer
  and Reasoner outputs in the encoding network.
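A minimal sketch of the structured representation described above, assuming the Explainer maps an observation to a vector of shared attributes and the Reasoner maps the observation, together with the Explainer output, to a vector of per-observation factors. `explainer`, `reasoner`, `encode` and `predictor` are hypothetical toy functions standing in for the ResNets and GPTs inside the API; the point is only that the latent is the concatenation of the two outputs and goes to the prediction model as-is, with no decoder in between.

```python
from typing import List

Vector = List[float]


def explainer(x: Vector) -> Vector:
    """Hypothetical Explainer: attributes shared across the whole observation."""
    mean = sum(x) / len(x)
    spread = max(x) - min(x)
    return [mean, spread]


def reasoner(x: Vector, e: Vector) -> Vector:
    """Hypothetical Reasoner: per-feature factors relative to the explanation."""
    mean = e[0]
    return [v - mean for v in x]


def encode(x: Vector) -> Vector:
    """Latent representation = concatenation of Explainer and Reasoner outputs."""
    e = explainer(x)
    r = reasoner(x, e)
    return e + r  # list concatenation, analogous to torch.cat([e, r], dim=-1)


def predictor(latent: Vector) -> int:
    """A downstream model that consumes the latent directly (no decoding step)."""
    return int(latent[0] > 0)


x = [1.0, 2.0, 3.0]
latent = encode(x)
print(latent)             # [2.0, 2.0, -1.0, 0.0, 1.0]
print(predictor(latent))  # 1
```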

This structured approach allows the model to learn from the encoded
data and make accurate predictions more efficiently.

### Dataset

[Recyclable and Household Waste Classification](https://www.kaggle.com/datasets/alistairking/recyclable-and-household-waste-classification)
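The `WasteDataset` class in this notebook expects one folder per class, each containing `default` and `real_world` subfolders. A small sketch for checking that layout before training; the helper name `count_images` is hypothetical.

```python
import os


def count_images(root_dir, subfolders=("default", "real_world")):
    """Return {(class_name, subfolder): file_count} for a one-folder-per-class layout.

    Raises FileNotFoundError if a class folder lacks one of `subfolders`.
    """
    counts = {}
    for class_name in sorted(os.listdir(root_dir)):
        class_dir = os.path.join(root_dir, class_name)
        if not os.path.isdir(class_dir):
            continue  # ignore stray files such as .DS_Store
        for sub in subfolders:
            counts[(class_name, sub)] = len(os.listdir(os.path.join(class_dir, sub)))
    return counts


# Usage in this notebook (same relative path the datasets are built from):
# counts = count_images("../../data/Recyclable and Household Waste Classification/images/images")
# print(len({name for name, _ in counts}), "classes,", sum(counts.values()), "images")
```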

In [1]:
import sys

path_append = "../" # Go up one directory from where you are.
sys.path.append(path_append) 

from tools.setting.ml_params import MLParameters
from tools.setting.data_config import DataConfig
from nn.utils.init import set_random_seed
set_random_seed(0)
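`set_random_seed` comes from this repository's `nn.utils.init` module, whose body is not shown here. Helpers of this kind usually seed every random number generator the pipeline touches; a hypothetical equivalent, `seed_everything`, is sketched below. Note that seeding only fixes the starting state of the global generators: every later draw (for example each `random.shuffle` call) still advances that state and produces different numbers.

```python
import random


def seed_everything(seed: int) -> None:
    """Hypothetical seeding helper: seeds Python, and NumPy/PyTorch if installed."""
    random.seed(seed)
    try:
        import numpy as np
        np.random.seed(seed)
    except ImportError:
        pass
    try:
        import torch
        torch.manual_seed(seed)  # seeds the RNG on all devices (CPU and CUDA)
    except ImportError:
        pass


seed_everything(0)
first = random.random()
seed_everything(0)
print(random.random() == first)  # True: same seed, same first draw
```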

In [2]:
dataset_path = '../data/Recyclable and Household Waste Classification/images/images'

In [3]:
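import os
import random

import torch
import torch.nn.functional as F
from PIL import Image
from torch.utils.data import Dataset


class WasteDataset(Dataset):
    """Recyclable and Household Waste images, split 60/20/20 into train/val/test.

    Each class folder holds 'default' and 'real_world' subfolders; every
    subfolder is split separately, so both image styles appear in each split.
    """

    def __init__(self, root_dir, split, transform=None, seed=0):
        if split not in ('train', 'val', 'test'):
            raise ValueError(f"split must be 'train', 'val' or 'test', got {split!r}")
        self.root_dir = root_dir
        self.transform = transform
        self.classes = sorted(os.listdir(root_dir))
        self.image_paths = []
        self.labels = []
        # A dedicated, seeded RNG gives every split instance the same shuffle,
        # so the train/val/test slices are disjoint. The global `random` module
        # would give each WasteDataset instance a different permutation and
        # leak images across splits.
        rng = random.Random(seed)

        for i, class_name in enumerate(self.classes):
            class_dir = os.path.join(root_dir, class_name)
            for subfolder in ['default', 'real_world']:
                subfolder_dir = os.path.join(class_dir, subfolder)
                # Sort first: os.listdir order is arbitrary and platform-dependent.
                image_names = sorted(os.listdir(subfolder_dir))
                rng.shuffle(image_names)

                n = len(image_names)
                if split == 'train':
                    image_names = image_names[:int(0.6 * n)]
                elif split == 'val':
                    image_names = image_names[int(0.6 * n):int(0.8 * n)]
                else:  # split == 'test'
                    image_names = image_names[int(0.8 * n):]

                for image_name in image_names:
                    self.image_paths.append(os.path.join(subfolder_dir, image_name))
                    self.labels.append(i)

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, index):
        image_path = self.image_paths[index]
        label = self.labels[index]
        image = Image.open(image_path).convert('RGB')

        if self.transform:
            image = self.transform(image)

        # One-hot encode the class index (30 classes in this dataset).
        label = torch.tensor(label, dtype=torch.long)
        label = F.one_hot(label, num_classes=len(self.classes))

        return image, label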
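Each split is produced by a separate `WasteDataset` instance, so the 60/20/20 slices only partition a class folder if every instance shuffles it into the same order. The sketch below contrasts shuffling with the shared global generator against shuffling with a dedicated `random.Random(seed)`, on fake file names; the helper `split_names` is hypothetical.

```python
import random


def split_names(names, split, rng):
    """Shuffle a sorted copy of `names` with `rng` and return the 60/20/20 slice."""
    names = sorted(names)
    rng.shuffle(names)
    n = len(names)
    bounds = {'train': (0, int(0.6 * n)),
              'val': (int(0.6 * n), int(0.8 * n)),
              'test': (int(0.8 * n), n)}
    lo, hi = bounds[split]
    return names[lo:hi]


files = [f"img_{k:03d}.png" for k in range(100)]

# Shared global generator: seeded once, then advanced by each split's shuffle,
# so train and val are cut from two different permutations.
random.seed(0)
train_g = set(split_names(files, 'train', random))
val_g = set(split_names(files, 'val', random))
print("overlap with global RNG:", len(train_g & val_g))

# A fresh generator with the same seed per split: identical shuffles, disjoint slices.
train = set(split_names(files, 'train', random.Random(0)))
val = set(split_names(files, 'val', random.Random(0)))
test = set(split_names(files, 'test', random.Random(0)))
print("overlap with dedicated RNG:", len(train & val), len(train & test), len(val & test))
```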

In [4]:
from torchvision import transforms

# Create the datasets and data loaders
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
train_dataset = WasteDataset(path_append + dataset_path, split='train', transform=transform)
val_dataset = WasteDataset(path_append + dataset_path, split='val', transform=transform)
test_dataset = WasteDataset(path_append + dataset_path, split='test', transform=transform)

X, y = train_dataset[0]
print(X.shape)
print(y.shape)

torch.Size([3, 224, 224])
torch.Size([30])
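The printed shapes follow from the transform and the label encoding: `ToTensor` converts an RGB image to a channels-first tensor and scales 8-bit values from [0, 255] to [0, 1], `Normalize` then computes `(value - mean) / std` per channel with the ImageNet statistics used above, and the label is a one-hot vector over the 30 classes. A pure-Python sketch of that per-pixel arithmetic and of recovering the class index from a one-hot label; the helper names are hypothetical.

```python
MEAN = (0.485, 0.456, 0.406)
STD = (0.229, 0.224, 0.225)


def normalize_pixel(rgb):
    """Mimic ToTensor + Normalize for one 8-bit RGB pixel."""
    return tuple((v / 255.0 - m) / s for v, m, s in zip(rgb, MEAN, STD))


def one_hot(index, num_classes=30):
    """One-hot label, as produced by F.one_hot(..., num_classes=30)."""
    return [1 if k == index else 0 for k in range(num_classes)]


def class_index(label):
    """Recover the class index from a one-hot label (argmax)."""
    return max(range(len(label)), key=label.__getitem__)


print(normalize_pixel((255, 255, 255)))  # approx. (2.249, 2.429, 2.640)
print(normalize_pixel((0, 0, 0)))        # approx. (-2.118, -2.036, -1.804)
label = one_hot(7)
print(len(label), class_index(label))    # 30 7
```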


In [5]:

data_config = DataConfig(dataset_name='recycle_image', task_type='multi_label_classification',
                         obs_shape=[3, 224, 224], label_size=30)

# Set the training configuration with MLParameters: GPT as the core model, ResNet as the encoder model.
ml_params = MLParameters(core_model='gpt', encoder_model='resnet')
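`obs_shape` and `label_size` in `DataConfig` should describe the tensors the dataset actually returns: `[3, 224, 224]` images and 30-dimensional one-hot labels, as printed earlier. A hypothetical helper, `check_data_config`, for catching a mismatch before the trainer is built:

```python
def check_data_config(obs_shape, label_size, sample_x_shape, sample_y_shape):
    """Raise ValueError if the config does not match one (X, y) sample's shapes."""
    if list(sample_x_shape) != list(obs_shape):
        raise ValueError(f"obs_shape {list(obs_shape)} != sample shape {list(sample_x_shape)}")
    if list(sample_y_shape) != [label_size]:
        raise ValueError(f"label_size {label_size} != label shape {list(sample_y_shape)}")


# In this notebook: check_data_config([3, 224, 224], 30, X.shape, y.shape)
# (torch.Size is a tuple subclass, so list() works on it.)
check_data_config([3, 224, 224], 30, (3, 224, 224), (30,))  # passes silently
```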

In [6]:
from trainer_hub import TrainerHub

# Set the device to GPU if available, else CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 

# Initialize the TrainerHub class with the training configuration, data configuration, device, and use_print and use_wandb flags
trainer_hub = TrainerHub(ml_params, data_config, device, use_print=True, use_wandb=False)

In [7]:
trainer_hub.train(train_dataset, val_dataset)

[0/100][50/140][Time 43.10]
Unified LR across all optimizers: 0.0001995308238189185
--------------------Training Metrics--------------------
Trainer:  resnet
Inf: 1.3029	Gen: 1.2072	Rec: 1.3920	E: 1.1181	R: 1.4877	P: 1.2964
Trainer:  gpt
Inf: 0.2419	Gen: 0.5568	Rec: 0.5008	E: 0.2979	R: 0.1858	P: 0.8157
--------------------Test Metrics------------------------
accuracy: 0.4781

[0/100][100/140][Time 41.53]
Unified LR across all optimizers: 0.00019907191565870155
--------------------Training Metrics--------------------
Trainer:  resnet
Inf: 1.1712	Gen: 1.0934	Rec: 1.2027	E: 1.0620	R: 1.2804	P: 1.1249
Trainer:  gpt
Inf: 0.0374	Gen: 0.3859	Rec: 0.3840	E: 0.0393	R: 0.0355	P: 0.7325
--------------------Test Metrics------------------------
accuracy: 0.6176

[1/100][10/140][Time 41.87]
Unified LR across all optimizers: 0.00019861406295796434
--------------------Training Metrics--------------------
Trainer:  resnet
Inf: 1.1155	Gen: 1.0590	Rec: 1.1533	E: 1.0212	R: 1.2097	P: 1.0968
Trainer:  gpt
Inf: 0.0431	Gen: 0.4393	Rec: 0.4349	E: 0.0475	R: 0.0387	P: 0.8312
--------------------Test Metrics------------------------
accuracy: 0.5712

[1/100][60/140][Time 41.53]
Unified LR across all optimizers: 0.00019815726328921765
--------------------Training Metrics--------------------
Trainer:  resnet
Inf: 1.0765	Gen: 1.0427	Rec: 1.1326	E: 0.9866	R: 1.1665	P: 1.0988
Trainer:  gpt
Inf: 0.1032	Gen: 0.4885	Rec: 0.4720	E: 0.1196	R: 0.0868	P: 0.8573
--------------------Test Metrics------------------------
accuracy: 0.8053

[1/100][110/140][Time 41.56]
Unified LR across all optimizers: 0.00019770151423055492
--------------------Training Metrics--------------------
Trainer:  resnet
Inf: 1.0512	Gen: 1.0274	Rec: 1.1129	E: 0.9658	R: 1.1367	P: 1.0890
Trainer:  gpt

[2/100][20/140][Time 41.74]
Unified LR across all optimizers: 0.00019724681336564005
--------------------Training Metrics--------------------
Trainer:  resnet
Inf: 1.0484	Gen: 1.0242	Rec: 1.1056	E: 0.9670	R: 1.1298	P: 1.0814
Trainer:  gpt
Inf: 0.0564	Gen: 0.3590	Rec: 0.3547	E: 0.0607	R: 0.0522	P: 0.6573
--------------------Test Metrics------------------------
accuracy: 0.8870

[2/100][70/140][Time 41.89]
Unified LR across all optimizers: 0.00019679315828369438
--------------------Training Metrics--------------------
Trainer:  resnet
Inf: 1.0522	Gen: 1.0226	Rec: 1.1039	E: 0.9709	R: 1.1335	P: 1.0742
Trainer:  gpt
Inf: 0.0598	Gen: 0.3738	Rec: 0.3690	E: 0.0646	R: 0.0550	P: 0.6830
--------------------Test Metrics------------------------
accuracy: 0.8979

[2/100][120/140][Time 41.89]
Unified LR across all optimizers: 0.00019634054657948372
--------------------Training Metrics--------------------
Trainer:  resnet
Inf: 1.0422	Gen: 1.0110	Rec: 1.0946	E: 0.9585	R: 1.1258	P: 1.0634
Trainer:  gpt

[3/100][30/140][Time 41.65]
Unified LR across all optimizers: 0.00019588897585330582
--------------------Training Metrics--------------------
Trainer:  resnet
Inf: 1.0356	Gen: 1.0089	Rec: 1.0859	E: 0.9586	R: 1.1125	P: 1.0592
Trainer:  gpt
Inf: 0.0484	Gen: 0.3556	Rec: 0.3531	E: 0.0510	R: 0.0459	P: 0.6603
--------------------Test Metrics------------------------
accuracy: 0.8931

[3/100][80/140][Time 41.56]
Unified LR across all optimizers: 0.00019543844371097777
--------------------Training Metrics--------------------
Trainer:  resnet
Inf: 1.0405	Gen: 1.0130	Rec: 1.0888	E: 0.9647	R: 1.1162	P: 1.0613
Trainer:  gpt
Inf: 0.0445	Gen: 0.3630	Rec: 0.3611	E: 0.0464	R: 0.0426	P: 0.6797
--------------------Test Metrics------------------------
accuracy: 0.8988

[3/100][130/140][Time 41.97]
Unified LR across all optimizers: 0.00019498894776382288
--------------------Training Metrics--------------------
Trainer:  resnet
Inf: 1.0437	Gen: 1.0142	Rec: 1.0926	E: 0.9653	R: 1.1220	P: 1.0631
Trainer:  gpt

[4/100][40/140][Time 41.82]
Unified LR across all optimizers: 0.00019454048562865856
--------------------Training Metrics--------------------
Trainer:  resnet
Inf: 1.0342	Gen: 1.0093	Rec: 1.0854	E: 0.9581	R: 1.1104	P: 1.0605
Trainer:  gpt
Inf: 0.0403	Gen: 0.3789	Rec: 0.3780	E: 0.0413	R: 0.0393	P: 0.7166
--------------------Test Metrics------------------------
accuracy: 0.9116

[4/100][90/140][Time 41.83]
Unified LR across all optimizers: 0.00019409305492778308
--------------------Training Metrics--------------------
Trainer:  resnet
Inf: 1.0344	Gen: 1.0054	Rec: 1.0815	E: 0.9583	R: 1.1106	P: 1.0525
Trainer:  gpt
Inf: 0.0377	Gen: 0.4065	Rec: 0.4053	E: 0.0389	R: 0.0365	P: 0.7740
--------------------Test Metrics------------------------
accuracy: 0.9267

[5/100][0/140][Time 41.91]
Unified LR across all optimizers: 0.00019364665328896346
--------------------Training Metrics--------------------
Trainer:  resnet
Inf: 1.0452	Gen: 1.0080	Rec: 1.0835	E: 0.9696	R: 1.1207	P: 1.0463
Trainer:  gpt
Inf: 0.0379	Gen: 0.4135	Rec: 0.4122	E: 0.0392	R: 0.0366	P: 0.7878
--------------------Test Metrics------------------------
accuracy: 0.9199

[5/100][50/140][Time 41.81]
Unified LR across all optimizers: 0.00019320127834542263
--------------------Training Metrics--------------------
Trainer:  resnet
Inf: 1.0423	Gen: 1.0059	Rec: 1.0812	E: 0.9669	R: 1.1176	P: 1.0448
Trainer:  gpt
Inf: 0.0406	Gen: 0.4170	Rec: 0.4155	E: 0.0420	R: 0.0391	P: 0.7920
--------------------Test Metrics------------------------
accuracy: 0.9163
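The E, R and P columns in these logs are not independent of the first three. In the logged rows they match, to within rounding of the last digit, the combinations E = Inf + Gen - Rec, R = Rec + Inf - Gen and P = Gen + Rec - Inf. This relationship is inferred from the numbers above rather than taken from the API documentation; the sketch below recomputes it for the first `resnet` and `gpt` rows.

```python
def derived_errors(inf, gen, rec):
    """Recompute the E, R, P columns from Inf, Gen, Rec (relationship inferred from the logs)."""
    e = inf + gen - rec
    r = rec + inf - gen
    p = gen + rec - inf
    return e, r, p


# (Inf, Gen, Rec) and logged (E, R, P) from the first training-metrics block.
rows = {
    "resnet": ((1.3029, 1.2072, 1.3920), (1.1181, 1.4877, 1.2964)),
    "gpt": ((0.2419, 0.5568, 0.5008), (0.2979, 0.1858, 0.8157)),
}

for name, (losses, logged) in rows.items():
    computed = derived_errors(*losses)
    # Logged values are rounded to 4 decimals, so allow a last-digit difference.
    assert all(abs(c - l) <= 2e-4 for c, l in zip(computed, logged)), name
    print(name, [round(c, 4) for c in computed])
```

A direct consequence of these combinations is that E + R = 2 * Inf, R + P = 2 * Rec and E + P = 2 * Gen, so each pair of error columns recovers one of the three underlying losses.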

