In [1]:
# !pip install datasets
# !pip install transformers
# !pip install transformers[torch]
# !pip install 'accelerate>=0.26.0'

In [2]:
# Uncomment the following lines to download the required data from Hugging Face (wait until all 1500 .hf folders have been fetched)
# !sudo apt-get install git-lfs
# !git lfs install 
# !git clone https://huggingface.co/datasets/tuan124816/newcs2_data

In [3]:
import os
from dataclasses import dataclass
from datasets import Image
import numpy as np
import torch
from datasets import load_from_disk
from transformers import DecisionTransformerConfig, DecisionTransformerModel, Trainer, TrainingArguments
import torch.nn.functional as F
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True  
from PIL import Image
from torchvision.models import efficientnet_b0
from torchvision.transforms import Compose, Normalize, ToTensor
from torch.quantization import quantize_dynamic
os.environ["WANDB_DISABLED"] = "true" 
from datasets import disable_progress_bars
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
from functools import partial
import shutil
disable_progress_bars()

Using device: cuda


In [4]:
# Preprocessing pipeline applied to every frame before it is fed to EfficientNet:
# ToTensor turns the PIL image into a tensor, then Normalize applies a single
# mean/std value of 0.5.
# NOTE(review): the pretrained weights presumably expect ImageNet channel
# statistics rather than 0.5/0.5 - confirm before changing, since existing
# checkpoints were trained on features produced with these values.
efficientnet_preprocessor = Compose([
    ToTensor(),
    Normalize(mean=[0.5], std=[0.5])  # same mean/std value given for normalization
])

# Load EfficientNet-B0 with pretrained weights and move it to the selected device.
efficientnet = efficientnet_b0(pretrained=True).to(device)
# Dynamic int8 quantization restricted to torch.nn.Linear modules.
# NOTE(review): feature extraction in this notebook only calls
# `efficientnet.features(...)`; confirm whether that path contains any Linear
# module, i.e. whether this quantization affects the extracted features at all.
efficientnet = quantize_dynamic(efficientnet, {torch.nn.Linear}, dtype=torch.qint8)
# Switch to inference mode; as the last expression of the cell this also
# displays the model structure.
efficientnet.eval()



EfficientNet(
  (features): Sequential(
    (0): Conv2dNormActivation(
      (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): SiLU(inplace=True)
    )
    (1): Sequential(
      (0): MBConv(
        (block): Sequential(
          (0): Conv2dNormActivation(
            (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
            (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): SiLU(inplace=True)
          )
          (1): SqueezeExcitation(
            (avgpool): AdaptiveAvgPool2d(output_size=1)
            (fc1): Conv2d(32, 8, kernel_size=(1, 1), stride=(1, 1))
            (fc2): Conv2d(8, 32, kernel_size=(1, 1), stride=(1, 1))
            (activation): SiLU(inplace=True)
            (scale_activation): Sigmoid()
          )
          (2): Conv2dNormActivat

In [5]:
def offset_observations(example, offset_amount = 1, offset_column = 1):
    """Circularly shift the per-timestep columns of one dataset row.

    The columns are rolled along their first (timestep) dimension, so values
    pushed past the end wrap around to the front.

    Args:
        example: Mapping with 'observations', 'actions' and 'rewards' entries
            (lists or tensors). It is modified in place.
        offset_amount: Number of positions to roll by.
        offset_column: How many columns to roll, in the fixed order
            observations, actions, rewards: ``>= 1`` rolls observations,
            ``>= 2`` also actions, ``>= 3`` also rewards.

    Returns:
        The same ``example`` object, with the rolled columns stored as nested
        Python lists.
    """
    ordered_columns = ('observations', 'actions', 'rewards')
    for threshold, column in enumerate(ordered_columns, start=1):
        if offset_column >= threshold:
            # as_tensor accepts lists and tensors alike; unlike torch.tensor(tensor)
            # it does not emit the copy-construct UserWarning. torch.roll already
            # returns a new tensor, so no explicit clone is needed.
            values = torch.as_tensor(example[column])
            example[column] = torch.roll(values, shifts=offset_amount, dims=0).tolist()
    return example

In [6]:
def resize_img(img):
    """Turn one flattened grayscale frame into an EfficientNet-ready tensor.

    Args:
        img: Tensor holding 150 * 412 values (the image part of an observation).

    Returns:
        The output of ``efficientnet_preprocessor`` for the frame after it has
        been reshaped to 150x412, cast to uint8, resized to 224x224 and
        replicated into three identical channels.
    """
    # Restore the 2D layout and cast to 8-bit so PIL can read it as a grayscale image.
    frame = img.numpy().reshape(150, 412).astype(np.uint8)

    # Resize to the 224x224 input resolution used for EfficientNet.
    resized = np.array(Image.fromarray(frame).resize((224, 224)))

    # Repeat the single channel three times to obtain an RGB-shaped image.
    rgb_image = Image.fromarray(np.stack((resized, resized, resized), axis=-1))

    return efficientnet_preprocessor(rgb_image)

def extract_features_with_efficientnet(images, batch_size=8):
    """Run the module-level EfficientNet feature extractor over a list of images.

    Args:
        images: List of preprocessed image tensors of identical shape.
        batch_size: Number of images pushed through the network at once.

    Returns:
        One tensor with a globally average-pooled feature vector per image,
        in input order.
    """

    def _pooled_features(chunk):
        # Stack the chunk into a batch on the target device and pool the
        # feature map over its two spatial dimensions.
        stacked = torch.stack(chunk).to(device)
        with torch.no_grad():
            feature_map = efficientnet.features(stacked)
            return torch.mean(feature_map, dim=(2, 3))

    chunk_starts = range(0, len(images), batch_size)
    per_chunk = [_pooled_features(images[start:start + batch_size]) for start in chunk_starts]
    return torch.cat(per_chunk, dim=0)


def get_states_and_images(dataset):
    """Split every observation of a dataset into its state part and its image part.

    Args:
        dataset: Dataset whose rows carry an "observations" sequence; each
            observation holds 3 state values followed by a flattened image.

    Returns:
        ``(states, images)``: two lists in row-then-timestep order, where
        ``states`` holds the first three values of each observation and
        ``images`` holds the result of ``resize_img`` on the remainder.
    """
    states, images = [], []

    for episode in dataset.with_format("torch"):
        observations = episode["observations"]
        # The first three entries are the state values ...
        states.extend(observation[:3] for observation in observations)
        # ... everything after them is the flattened (1D) image.
        images.extend(resize_img(observation[3:]) for observation in observations)

    return states, images


def get_new_state(dataset, model):
    """Build the combined (state + image feature) observations for a dataset.

    Args:
        dataset: Dataset whose rows each hold one episode of observations;
            all episodes are expected to have the same number of timesteps.
        model: Currently unused. Image features are produced by
            ``extract_features_with_efficientnet``, which uses the module-level
            ``efficientnet``. The parameter is kept so existing callers keep working.

    Returns:
        Tensor of shape ``(len(dataset), timesteps_per_episode, feature_dim)``,
        where each feature vector is the 3 state values followed by the pooled
        image features.
    """
    states_list, images_list = get_states_and_images(dataset)
    states_features = torch.stack(states_list).to(device)

    # Extract image features with EfficientNet-B0
    image_features = extract_features_with_efficientnet(images_list)

    # Combine state and image features: one row per (episode, timestep) pair.
    combined_features = torch.cat((states_features, image_features), dim=1)

    # Regroup the flat rows per episode. The episode count comes from the dataset
    # rather than a hard-coded (20, 50), which would silently mix episodes for any
    # other dataset size. For 20 episodes of 50 steps the result is unchanged.
    n_episodes = len(dataset)
    return combined_features.reshape(n_episodes, -1, combined_features.shape[-1])

In [7]:
def update_observations(example, idx):
    """Swap a row's observations for the precomputed features at the same index.

    Reads the module-level ``new_obs_features``, which must be assigned before
    this function is mapped over a dataset.

    Args:
        example: The dataset row being processed (modified in place).
        idx: Index of the row, used to look up ``new_obs_features``.

    Returns:
        The same ``example`` with its 'observations' entry replaced.
    """
    replacement = new_obs_features[idx]
    example['observations'] = replacement
    return example

In [8]:
# Dataset discovery and loading (used when the data is read with load_from_disk)
import random
def get_all_hf_folders(base_dir):
    """Recursively collect every directory under ``base_dir`` whose name ends in '.hf'.

    Args:
        base_dir: Root directory to search.

    Returns:
        List of paths (``base_dir`` joined with the relative location), in the
        order ``os.walk`` visits them. Files are ignored; only directory names
        are matched.
    """
    return [
        os.path.join(parent, child)
        for parent, children, _files in os.walk(base_dir)
        for child in children
        if child.endswith('.hf')
    ]

# Root folder that is searched for `.hf` dataset directories. The commented-out
# assignments are alternative local locations.
# base_dir = "d:\\manual_data"
base_dir = "newcs2_data"
# base_dir = "d:\\newcs2_data_1"
hf_folder_paths = get_all_hf_folders(base_dir)
# Shuffle the folder order in place. No random.seed call is visible in this
# notebook, so the order presumably differs between kernel sessions.
random.shuffle(hf_folder_paths)
def load_dataset_generator(hf_folder_paths): # pass in the list of .hf dataset folder paths
    """Lazily load the datasets stored in the given folders, one at a time.

    Each folder path is printed, loaded with ``load_from_disk`` and its rows are
    shuffled with a fixed seed. Folders that fail to load are reported and
    skipped, so the generator may yield fewer datasets than there are paths.

    Args:
        hf_folder_paths: Iterable of dataset folder paths.

    Yields:
        The shuffled dataset for every folder that loaded successfully.
    """
    for folder in hf_folder_paths:
        print(folder)
        # Only the load itself is guarded. Keeping the `yield` outside the try
        # block ensures an exception thrown into the generator by the consumer
        # is not misreported as a load failure.
        try:
            loaded = load_from_disk(folder).shuffle(seed=42)
        except Exception as e:
            print(f"Failed to load with 'load_from_disk'. Error: {e}")
            continue
        yield loaded

# Generator over the shuffled folders; nothing is loaded until it is advanced,
# and each step loads one dataset from disk.
dataset_stream = load_dataset_generator(hf_folder_paths)

In [None]:
# Take the first dataset off the stream (this advances `dataset_stream`) and
# expose its rows as torch tensors.
first_dataset = next(iter(dataset_stream))
first_dataset = first_dataset.with_format("torch")
# `update_observations` looks this module-level name up by row index, so it has
# to be assigned before the map call below.
new_obs_features = get_new_state(first_dataset, efficientnet)
# Replace every row's observations with the matching entry of new_obs_features.
first_dataset = first_dataset.map(update_observations, with_indices=True)

In [10]:
@dataclass
class DecisionTransformerGymDataCollator:
    """Collator that turns the wrapped dataset into one Decision Transformer batch.

    The ``batch`` handed over by the data loader is ignored: every call returns
    all trajectories of ``self.dataset``. On each call the observations are
    delayed by a fixed offset relative to the actions, and all columns are
    additionally shifted by a random amount; the leading positions that receive
    wrapped-around values are excluded through the attention mask.
    """

    return_tensors: str = "pt"
    max_len: int = 50 #subsets of the episode we use for training
    state_dim: int = 1283  # size of state space
    act_dim: int = 51  # size of action space
    max_ep_len: int = 1000 # max episode length in the dataset
    scale: float = 1000.0  # normalization of rewards/returns
    state_mean: np.array = None  # to store state means
    state_std: np.array = None  # to store state stds
    p_sample: np.array = None  # a distribution to take account trajectory lengths
    n_traj: int = 0 # to store the number of trajectories in the dataset

    def __init__(self, dataset) -> None:
        """Store ``dataset`` (torch-formatted) as the source of every batch."""
        self.act_dim = 51
        self.state_dim = 1283
        self.dataset = dataset.with_format('torch')
        self.p_sample = np.array([0.05] * 20)

    def discount_cumsum(self, x, gamma):
        """Return the discounted reverse cumulative sum of ``x`` along axis 0.

        ``result[t] = x[t] + gamma * result[t + 1]`` with ``result[-1] = x[-1]``.
        The result has the same shape and dtype as ``x``; an empty input gives
        an empty result.
        """
        x = np.asarray(x)
        discounted = np.zeros_like(x)
        if x.shape[0] == 0:
            return discounted
        discounted[-1] = x[-1]
        for t in reversed(range(x.shape[0] - 1)):
            discounted[t] = x[t] + gamma * discounted[t + 1]
        return discounted

    def _build_batch(self, observations, actions, rewards, offset_amount, offset_padding):
        """Shift the trajectory tensors and derive mask, timesteps and returns-to-go.

        Args:
            observations: Tensor of shape (n_traj, seq_len, state_dim).
            actions: Tensor of shape (n_traj, seq_len, act_dim).
            rewards: Tensor with n_traj * seq_len entries.
            offset_amount: Extra delay applied to the observations only.
            offset_padding: Shift applied to observations, actions and rewards.

        Returns:
            Dict of CPU tensors with the keys the model expects. The inputs are
            not modified.

        Raises:
            ValueError: If the total shift leaves no unmasked timestep.
        """
        n_traj, seq_len = observations.shape[0], observations.shape[1]
        lead = offset_amount + offset_padding
        if not 0 <= lead < seq_len:
            raise ValueError(
                f"offset_amount + offset_padding must be in [0, {seq_len}), got {lead}"
            )

        # Roll along the time axis of every trajectory. Observations receive both
        # shifts, actions and rewards only the padding shift.
        states = torch.roll(observations, shifts=lead, dims=1)
        shifted_actions = torch.roll(actions, shifts=offset_padding, dims=1)
        shifted_rewards = torch.roll(rewards.reshape(n_traj, seq_len), shifts=offset_padding, dims=1)

        # discount_cumsum accumulates along axis 0, so hand it a (time, trajectory)
        # view: returns-to-go must be summed over time inside each trajectory, not
        # across trajectories.
        rtg = self.discount_cumsum(shifted_rewards.numpy().T, gamma=1.0).T
        returns_to_go = torch.from_numpy(np.ascontiguousarray(rtg)).float().reshape(n_traj, seq_len, 1)

        # The first `lead` positions hold wrapped-around values: mask them out and
        # start counting timesteps at the first valid position.
        mask = torch.cat((torch.zeros(lead), torch.ones(seq_len - lead)))
        time_step = torch.cat((torch.zeros(lead, dtype=torch.long), torch.arange(seq_len - lead)))

        return {
            "states": states,
            "actions": shifted_actions,
            "rewards": shifted_rewards.reshape(n_traj, seq_len, 1),
            "returns_to_go": returns_to_go,
            "timesteps": time_step.long().repeat(n_traj, 1),
            "attention_mask": mask.float().repeat(n_traj, 1),
        }

    def __call__(self, batch):
        """Return the full dataset as one batch with a freshly sampled shift.

        ``batch`` is ignored. ``self.dataset`` is left untouched, so shifts do
        not accumulate from one call to the next and the returned mask always
        matches the data it accompanies.
        """
        offset_amount = 2
        offset_padding = random.randrange(0,30)

        source = self.dataset.with_format("torch")
        tensors = self._build_batch(
            source["observations"],
            source["actions"],
            source["rewards"],
            offset_amount,
            offset_padding,
        )
        return {key: value.to(device) for key, value in tensors.items()}

In [11]:
# Model Decision Transformer part
class TrainableDT(DecisionTransformerModel):
    """Decision Transformer whose forward pass returns a per-action-group loss.

    The action vector is laid out as ``n_keys`` key columns, ``n_clicks`` mouse
    button columns, ``n_mouse_x`` horizontal mouse bins and ``n_mouse_y``
    vertical mouse bins. Key and click groups are scored with binary
    cross-entropy on logits, the two mouse groups with cross-entropy.
    """

    def __init__(self, config, gamma=0.99):
        """Initialise the parent model and record the action layout sizes."""
        super().__init__(config)
        self.gamma = gamma
        self.n_keys = 11
        self.n_clicks = 2
        mouse_x_possibles = [-1000.0,-500.0, -300.0, -200.0, -100.0, -60.0, -30.0, -20.0, -10.0, -4.0, -2.0, -0.0, 2.0, 4.0, 10.0, 20.0, 30.0, 60.0, 100.0, 200.0, 300.0, 500.0,1000.0]
        mouse_y_possibles = [-200.0, -100.0, -50.0, -20.0, -10.0, -4.0, -2.0, -0.0, 2.0, 4.0, 10.0, 20.0, 50.0, 100.0, 200.0]
        # Only the number of bins per axis is kept.
        self.n_mouse_x = len(mouse_x_possibles)
        self.n_mouse_y = len(mouse_y_possibles)

    def forward(self, **kwargs):
        """Run the parent forward pass and return ``{"loss": total_loss}``.

        Expects "actions" and "attention_mask" among ``kwargs``. Positions whose
        attention mask is not positive are excluded from the loss. The raw
        output, the masked predictions and the masked targets are also published
        through the module-level names ``model_output``, ``predict_value`` and
        ``targeter_value``.
        """
        global model_output, predict_value, targeter_value

        output = super().forward(**kwargs)
        model_output = output

        # The second element of the parent output is used as the action predictions.
        raw_preds = output[1]
        act_dim = raw_preds.shape[2]

        # Flatten (batch, time) and keep only the attended positions.
        keep = kwargs["attention_mask"].reshape(-1) > 0
        action_preds = raw_preds.reshape(-1, act_dim)[keep]
        predict_value = action_preds
        action_targets = kwargs["actions"].reshape(-1, act_dim)[keep]
        targeter_value = action_targets

        # Column boundaries of the action groups.
        keys_end = self.n_keys
        clicks_end = keys_end + self.n_clicks
        mouse_x_end = clicks_end + self.n_mouse_x
        mouse_y_end = mouse_x_end + self.n_mouse_y

        def binary_loss(columns):
            return F.binary_cross_entropy_with_logits(action_preds[:, columns], action_targets[:, columns])

        def categorical_loss(columns):
            return F.cross_entropy(action_preds[:, columns], action_targets[:, columns])

        # NOTE(review): key columns 5 and 6 are not covered by any loss term
        # below (space ends at column 5, weapon switch starts at n_keys - 4).
        # Confirm this is intentional.
        loss_wasd = binary_loss(slice(None, 4))
        loss_space = binary_loss(slice(4, 5))
        loss_weapon_switch = binary_loss(slice(keys_end - 4, keys_end - 1))
        loss_reload = binary_loss(slice(keys_end - 1, keys_end))
        loss_left_click = binary_loss(slice(keys_end, keys_end + 1))
        loss_right_click = binary_loss(slice(keys_end + 1, clicks_end))
        loss_mouse_move_x = categorical_loss(slice(clicks_end, mouse_x_end))
        loss_mouse_move_y = categorical_loss(slice(mouse_x_end, mouse_y_end))

        # Total loss: unweighted sum of all groups.
        total_loss = sum([loss_wasd,loss_space, loss_weapon_switch, loss_reload, loss_left_click, loss_right_click, loss_mouse_move_x, loss_mouse_move_y])

        # Progress report printed on every forward pass.
        separator = "+"*50
        print(separator)
        print("loss_wasd", loss_wasd, end='|||')
        print("loss_left_click", loss_left_click)
        print("loss_mouse_move_x", loss_mouse_move_x, end='|||')
        print("loss_mouse_move_y", loss_mouse_move_y)
        print("total loss", total_loss)
        print(separator)
        return {"loss": total_loss}

    def original_forward(self, **kwargs):
        """Expose the unmodified parent forward pass (no loss computation)."""
        return super().forward(**kwargs)

In [None]:
# Trainer configuration shared by every Trainer created in the cells below.
training_args = TrainingArguments(
    output_dir="trained_models_2/",  # same directory the later cells save their checkpoints under
    remove_unused_columns=False,
    num_train_epochs=25,
    # The collator in this notebook ignores the sampled batch and returns the
    # whole wrapped dataset on every call, so this value does not set the
    # number of trajectories per step.
    per_device_train_batch_size=1, 
    learning_rate=5e-5,
    weight_decay=0.0001,
    # NOTE(review): the fused optimizer and tf32=True presumably require a CUDA
    # device - confirm before running on the CPU fallback.
    optim = "adamw_torch_fused",
    # lr_scheduler_type="cosine",
    max_grad_norm=1.0,
    dataloader_pin_memory = False,
    tf32=True,
)

In [None]:
# Collator that builds every training batch from the first (already featurized) dataset.
collator = DecisionTransformerGymDataCollator(first_dataset)

# No cell of this notebook creates `model`, so on a fresh kernel the Trainer
# below would fail with NameError. Build a new model from the collator's
# dimensions unless one is already present (e.g. loaded manually from a checkpoint).
# NOTE(review): every other config value is left at the library default -
# confirm these match the hyperparameters of the existing checkpoints.
if "model" not in globals():
    config = DecisionTransformerConfig(state_dim=collator.state_dim, act_dim=collator.act_dim)
    model = TrainableDT(config).to(device)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=first_dataset,
    data_collator=collator,
)

# train the first dataset
trainer.train()
# This checkpoint is the starting point of pass 0 in the loop cell below.
save_path = "trained_models_2/model_after_dataset_1_in_loops_0"
trainer.save_model(save_path)
print(f"Model saved to {save_path}")

In [None]:
def _checkpoint_path(index, loop):
    """Return the checkpoint directory written after dataset `index` of pass `loop`."""
    return f"trained_models_2/model_after_dataset_{index}_in_loops_{loop}"


def _train_on_dataset(dataset, load_index, load_loop, save_path):
    """Featurize `dataset`, resume from a checkpoint, train on it and save the result.

    Args:
        dataset: Raw dataset as yielded by `load_dataset_generator`.
        load_index: Dataset index of the checkpoint to resume from.
        load_loop: Pass number of the checkpoint to resume from.
        save_path: Directory the trained model is saved to.

    Returns:
        `(model, trainer)` so the caller can keep the module-level names current.
    """
    # update_observations reads this module-level name, so it has to be set globally.
    global new_obs_features
    dataset = dataset.with_format("torch")
    new_obs_features = get_new_state(dataset, efficientnet)
    dataset = dataset.map(update_observations, with_indices=True)
    collator = DecisionTransformerGymDataCollator(dataset)
    print(f"USE ++ model_after_dataset_{load_index} ++ FOR TRAINING NOW")
    model = TrainableDT.from_pretrained(_checkpoint_path(load_index, load_loop)).to(device)
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=dataset,
        data_collator=collator,
    )
    trainer.train()
    trainer.save_model(save_path)
    return model, trainer


for loops in range(100):
    # Pass 0 resumes an earlier run: set 53 to (index of the last saved model + 1),
    # or to 2 when starting right after the first-dataset cell.
    # Later passes have already trained on dataset 1 at the end of the previous
    # pass (saved as index 1), so they must continue at index 2. Reusing the
    # resume index there would look for a checkpoint that was never written.
    start_num = 53 if loops == 0 else 2
    last_index = start_num - 1  # checkpoint to continue from if the stream yields nothing
    for i, dataset in enumerate(dataset_stream, start=start_num):
        save_path = _checkpoint_path(i, loops)
        model, trainer = _train_on_dataset(dataset, i - 1, loops, save_path)
        last_index = i

        # Checkpoint i-2 is no longer needed by any later step.
        # The original `i-2 % 10 == 0` parsed as `i - (2 % 10) == 0`, i.e. it was
        # only true for i == 2.
        # NOTE(review): if the goal is to KEEP every 10th checkpoint and delete
        # the rest, this condition should be `!= 0` - confirm the intent.
        stale_path = _checkpoint_path(i - 2, loops)
        if os.path.exists(stale_path) and (i - 2) % 10 == 0:
            shutil.rmtree(stale_path)
        print(f"Model saved to {save_path}")

    # Start the next pass over the folders: train on its first dataset, starting
    # from the last checkpoint this pass actually wrote.
    dataset_stream = load_dataset_generator(hf_folder_paths)
    first_dataset = next(dataset_stream, None)
    if first_dataset is None:
        print("No dataset could be loaded; stopping.")
        break
    save_path = _checkpoint_path(1, loops + 1)
    model, trainer = _train_on_dataset(first_dataset, last_index, loops, save_path)