In [None]:
# -*- coding: utf-8 -*-
"""
Created on Tue Dec 17 00:00:00 2023
@author: chun (refactored)
"""
import os
import re
import glob
import time
import torch
import yaml
import numpy as np
import torch.nn as nn
import torch.optim as optim
from fractions import Fraction
from torchvision import transforms, datasets
from torch.utils.data import DataLoader
from torch.nn.parallel import DataParallel
from tqdm import tqdm
from tensorboardX import SummaryWriter

from model2 import DeepJSCC, ratio2filtersize
from utils import image_normalization, set_seed, view_model_param
from dataset import Vanilla


import torch
from model import  DeepJSCC as DJ

def config_parser_pipeline():
    """Parse the command-line options of the training pipeline.

    ``parse_known_args`` is used so that unrecognised arguments (such as the
    ``--f=<kernel.json>`` flag a Jupyter kernel injects) are reported and
    skipped rather than aborting the run. ``snr_list`` and ``ratio_list`` are
    returned as lists of strings; the caller converts them to numbers.

    Returns:
        argparse.Namespace with the recognised options.
    """
    import argparse

    cli = argparse.ArgumentParser()
    cli.add_argument('--dataset', type=str, default='imagenet',
                     choices=['cifar10', 'imagenet'], help='Dataset name')
    cli.add_argument('--out', type=str, default='./out',
                     help='Output path for logs and checkpoints')
    cli.add_argument('--disable_tqdm', action='store_true',
                     help='Disable tqdm progress bars')
    cli.add_argument('--device', type=str, default='cuda:0',
                     help='Device: cuda:0 / cpu')
    cli.add_argument('--parallel', action='store_true',
                     help='Use DataParallel if multiple GPUs are available')
    cli.add_argument('--snr_list', nargs='+', default=['11', '19'],
                     help='List of SNR values (e.g. 5 10 15)')
    cli.add_argument('--ratio_list', nargs='+', default=['1/6', '1/12'],
                     help='List of channel ratios (e.g. 1/6 1/12)')
    cli.add_argument('--channel', type=str, default='Rayleigh',
                     choices=['AWGN', 'Rayleigh'], help='Channel type')

    known, extras = cli.parse_known_args()
    if extras:
        print(f"Ignoring unknown args: {extras}")
    return known

def auto_find_checkpoint0(dataset, c, snr, ratio, channel, base_dir='./out/checkpoints'):
    """Return the newest checkpoint file of the run matching the given settings.

    A run directory matches when its name is exactly
    ``{DATASET}_{c}_{snr}_{ratio:.2f}_{channel}``, optionally followed by a
    ``_HHhMMmSSs_on_<Mon>_<DD>_<YYYY>`` timestamp. Names with anything else
    between the channel and the timestamp are rejected. An example is the
    ``_{Nsamples}`` segment that ``train_pipeline`` adds to its fine-tuning
    runs (e.g. ``IMAGENET_8_19.0_0.17_AWGN_10_21h24m09s_on_Jul_18_2025``).
    A plain prefix test would also match those.

    Among matching directories the most recently modified one is used. Within
    it, the most recently modified ``epoch_*.pth`` file is returned.

    Raises:
        FileNotFoundError: if *base_dir* is missing, no directory matches, or
            the chosen directory holds no ``epoch_*.pth`` file.
    """
    prefix = f"{dataset.upper()}_{c}_{snr}_{ratio:.2f}_{channel}"
    # Exact prefix, then at most one run timestamp (format used by time.strftime
    # '%Hh%Mm%Ss_on_%b_%d_%Y'); [^_]+ tolerates locale-dependent month names.
    dir_pattern = re.compile(
        re.escape(prefix) + r"(?:_\d{2}h\d{2}m\d{2}s_on_[^_]+_\d{2}_\d{4})?"
    )
    candidates = [
        os.path.join(base_dir, d)
        for d in os.listdir(base_dir)
        if dir_pattern.fullmatch(d) and os.path.isdir(os.path.join(base_dir, d))
    ]
    if not candidates:
        raise FileNotFoundError(f"No checkpoint directories found with prefix: {prefix}")
    latest_dir = max(candidates, key=os.path.getmtime)
    ckpts = glob.glob(os.path.join(latest_dir, 'epoch_*.pth'))
    if not ckpts:
        raise FileNotFoundError(f"No checkpoint files in: {latest_dir}")
    latest_ckpt = max(ckpts, key=os.path.getmtime)
    print(f"Found checkpoint: {latest_ckpt}")
    return latest_ckpt



def load_model(ratio, snr=19.0, channel_type='AWGN', dataset='imagenet',
               image_size=(64, 64), device=None):
    """Load the pretrained reference ``model.DeepJSCC`` (imported as ``DJ``) for *ratio*.

    The inner channel width ``c`` comes from ``ratio2filtersize`` applied to a
    dummy image of shape ``(3, *image_size)``. The checkpoint is then located
    with ``auto_find_checkpoint0(dataset, c, snr, ratio, channel_type)``. The
    defaults reproduce the settings that were previously hard-coded
    (imagenet, 64x64, snr=19.0, AWGN).

    Args:
        ratio: channel ratio passed to ``ratio2filtersize`` (e.g. 1/6).
        snr: SNR value identifying the pretrained run and passed to the model.
        channel_type: channel name identifying the run and passed to the model.
        dataset: dataset name identifying the run.
        image_size: (H, W) of the dummy image used to derive ``c``.
        device: target device. Defaults to 'cuda:0' when CUDA is available,
            otherwise 'cpu'.

    Returns:
        The loaded model, moved to *device* and switched to eval mode.

    Raises:
        FileNotFoundError: propagated from ``auto_find_checkpoint0`` when no
            matching checkpoint exists.
    """
    if device is None:
        device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

    # randn (rather than zeros) keeps the global RNG stream consumed exactly
    # as before, so seeded runs stay reproducible.
    dummy_img = torch.randn(3, *image_size)
    c = ratio2filtersize(dummy_img, ratio)
    checkpoint_path = auto_find_checkpoint0(dataset, c, snr, ratio, channel_type)

    model = DJ(c=c, snr=snr, channel_type=channel_type)
    # weights_only=True refuses to unpickle arbitrary objects. It assumes the
    # file holds a plain state_dict, as torch.save(model.state_dict()) writes.
    state_dict = torch.load(checkpoint_path, map_location=device, weights_only=True)
    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()
    return model



def train_epoch(model, optimizer, param, data_loader, Nsamples):
    """Run one training epoch with ``Nsamples`` optimizer steps per batch.

    For every batch, the full ``Enc -> Chan -> Dec`` pass is repeated
    ``Nsamples`` times. Each repetition gets its own backward pass and its own
    ``optimizer.step()``. Loss is computed on denormalized images via the
    model's ``loss`` method.

    Args:
        model: the network, optionally wrapped in ``DataParallel``.
        optimizer: optimizer over the model's parameters.
        param: dict; ``param['device']`` is the device batches are moved to.
        data_loader: iterable of ``(images, labels)``; labels are ignored.
        Nsamples: number of forward/backward/step repetitions per batch.

    Returns:
        Mean of ``loss.item()`` over every (batch, repetition) step.

    Raises:
        ValueError: if no step ran (empty loader or ``Nsamples < 1``).
    """
    model.train()
    # DataParallel does not expose the wrapped module's sub-networks
    # (Enc/Chan/Dec) or its loss as attributes, so work on the inner module.
    net = model.module if isinstance(model, DataParallel) else model
    denormalize = image_normalization('denormalization')
    total_loss = 0.0
    num_steps = 0
    for images, _ in data_loader:
        images = images.to(param['device'])
        # The reconstruction target is identical for every repetition.
        target = denormalize(images)
        for _ in range(Nsamples):
            # Each repetition calls optimizer.step(), so gradients must be
            # cleared every time; otherwise step k would apply the sum of
            # all k gradients computed so far for this batch.
            optimizer.zero_grad()
            # The encoder pass stays inside the loop: backward() frees the
            # graph, so an encoding computed once could not be reused.
            outputs = net.Dec(net.Chan(net.Enc(images)))
            outputs = denormalize(outputs)
            loss = net.loss(target, outputs)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            num_steps += 1
    if num_steps == 0:
        raise ValueError("train_epoch: no training steps were run "
                         "(empty data_loader or Nsamples < 1)")
    return total_loss / num_steps



def evaluate_epoch(model, param, data_loader):
    """Compute the mean per-batch loss of *model* over *data_loader* without gradients.

    Each batch goes through ``Enc -> Chan -> Dec``. The loss is computed
    between the denormalized input and the denormalized reconstruction.

    Args:
        model: the network, optionally wrapped in ``DataParallel``.
        param: dict; ``param['device']`` is the device batches are moved to.
        data_loader: iterable of ``(images, labels)``; labels are ignored.

    Returns:
        Mean of ``loss.item()`` over all batches.

    Raises:
        ValueError: if *data_loader* yields no batches.
    """
    model.eval()
    # DataParallel does not expose Enc/Chan/Dec/loss as attributes.
    net = model.module if isinstance(model, DataParallel) else model
    denormalize = image_normalization('denormalization')
    total_loss = 0.0
    num_batches = 0
    with torch.no_grad():
        for images, _ in data_loader:
            images = images.to(param['device'])
            outputs = denormalize(net.Dec(net.Chan(net.Enc(images))))
            loss = net.loss(denormalize(images), outputs)
            total_loss += loss.item()
            num_batches += 1
    if num_batches == 0:
        raise ValueError("evaluate_epoch: data_loader yielded no batches")
    return total_loss / num_batches





def main_pipeline(Nsamples):
    """Parse the CLI and train one model per (ratio, SNR) combination.

    Ratios form the outer loop and SNRs the inner loop. *Nsamples* is
    forwarded unchanged to ``train_pipeline``.
    """
    args = config_parser_pipeline()
    # CLI values arrive as strings; ratios may be written as fractions ("1/6").
    args.snr_list = [float(value) for value in args.snr_list]
    args.ratio_list = [float(Fraction(value)) for value in args.ratio_list]

    print("📡 Training Start")
    for ratio in args.ratio_list:
        for snr in args.snr_list:
            train_pipeline(prepare_params(args, ratio, snr), Nsamples)


def prepare_params(args, ratio, snr):
    """Build the hyper-parameter dict for one (ratio, snr) training run.

    Run-specific entries come from *args*, *ratio* and *snr*. The training
    hyper-parameters are fixed per dataset: ``'cifar10'`` gets its own table,
    and any other dataset name uses the ImageNet table. Calls ``set_seed``
    with the chosen seed before returning.
    """
    run_cfg = {
        'disable_tqdm': args.disable_tqdm,
        'dataset': args.dataset,
        'out_dir': args.out,
        'device': args.device if torch.cuda.is_available() else 'cpu',
        'parallel': args.parallel,
        'snr': snr,
        'ratio': ratio,
        'channel': args.channel,
    }

    if args.dataset == 'cifar10':
        train_cfg = dict(
            batch_size=64, num_workers=4, epochs=1000,
            init_lr=1e-3, weight_decay=5e-4,
            if_scheduler=True, step_size=640, gamma=0.1,
            ReduceLROnPlateau=False, lr_reduce_factor=0.5,
            lr_schedule_patience=15, min_lr=1e-7, max_time=12,
            seed=42,
        )
    else:  # imagenet
        train_cfg = dict(
            batch_size=32, num_workers=4, epochs=750,
            init_lr=1e-4, weight_decay=5e-4,
            if_scheduler=True, gamma=0.1,
            ReduceLROnPlateau=True, lr_reduce_factor=0.5,
            lr_schedule_patience=15, min_lr=1e-8, max_time=12,
            seed=42,
        )

    # Merge keeps insertion order: run entries first, then training entries.
    params = {**run_cfg, **train_cfg}
    set_seed(params['seed'])
    return params


def _build_datasets(dataset_name):
    """Return the ``(train_ds, test_ds)`` pair for *dataset_name*.

    ``'cifar10'`` uses torchvision CIFAR10 under ``../dataset/``, downloaded if
    missing. Any other name loads ImageFolder trees from the hard-coded
    Sentry_Data paths, resized to 64x64.
    """
    if dataset_name == 'cifar10':
        transform = transforms.Compose([transforms.ToTensor()])
        train_ds = datasets.CIFAR10(root='../dataset/', train=True, download=True, transform=transform)
        test_ds = datasets.CIFAR10(root='../dataset/', train=False, download=True, transform=transform)
    else:  # imagenet
        # TODO: these absolute paths are machine-specific; consider making them configurable.
        transform_resize = transforms.Compose([transforms.Resize((64, 64)), transforms.ToTensor()])
        train_ds = datasets.ImageFolder('/home/MATLAB_DATA/TiNguyen/Sentry_Data/train', transform=transform_resize)
        test_ds = datasets.ImageFolder('/home/MATLAB_DATA/TiNguyen/Sentry_Data/test', transform=transform_resize)
    return train_ds, test_ds


def train_pipeline(params, Nsamples):
    """Fine-tune one ``model2.DeepJSCC`` for the (ratio, snr) stored in *params*.

    Steps:
      1. Build the train/test loaders.
      2. Derive the inner channel width ``c`` and build the model.
      3. Copy the encoder/decoder weights from the pretrained reference
         returned by ``load_model``.
      4. Train with Adam for up to ``params['epochs']`` epochs, logging to
         TensorBoard and keeping only the latest checkpoint.

    Training stops early when the learning rate drops below
    ``params['min_lr']``, when ``params['max_time']`` hours have elapsed, or on
    Ctrl-C. Afterwards the model is evaluated on both splits and the run
    config is written to ``<out_dir>/configs/<run>.yaml``.

    Note: ``setup_model_device`` replaces ``params['device']`` with a
    ``torch.device`` and rewrites ``params['parallel']``.
    """
    # Data setup
    train_ds, test_ds = _build_datasets(params['dataset'])
    train_loader = DataLoader(train_ds, batch_size=params['batch_size'],
                              shuffle=True, num_workers=params['num_workers'])
    test_loader = DataLoader(test_ds, batch_size=params['batch_size'],
                             shuffle=False, num_workers=params['num_workers'])

    # Model init: the bottleneck width depends on one sample's shape and the ratio.
    sample_img, _ = train_ds[0]
    c = ratio2filtersize(sample_img, params['ratio'])
    print(f"🔧 SNR={params['snr']}, inner channel c={c}, ratio={params['ratio']:.2f}")

    model = DeepJSCC(c=c, channel_type=params['channel'], snr=params['snr'],
                     M=16, num_e_bits=5, num_m_bits=10)

    # Warm-start encoder/decoder from the pretrained reference model, then drop
    # the reference so it does not occupy device memory for the whole run.
    model_ref = load_model(params['ratio'])
    model.encoder.load_state_dict(model_ref.encoder.state_dict())
    model.decoder.load_state_dict(model_ref.decoder.state_dict())
    del model_ref

    model = setup_model_device(model, params)

    # Optimizer and scheduler
    optimizer = optim.Adam(model.parameters(), lr=params['init_lr'], weight_decay=params['weight_decay'])
    scheduler = setup_scheduler(optimizer, params)

    # Logging directories
    phaser = (f"{params['dataset'].upper()}_{c}_{params['snr']}_{params['ratio']:.2f}_"
              f"{params['channel']}_{Nsamples}_{time.strftime('%Hh%Mm%Ss_on_%b_%d_%Y')}")
    log_dir = os.path.join(params['out_dir'], 'logs', phaser)
    ckpt_dir = os.path.join(params['out_dir'], 'checkpoints', phaser)
    os.makedirs(ckpt_dir, exist_ok=True)

    # Save the unwrapped model's weights so checkpoints carry no DataParallel
    # "module." key prefix and load directly into a bare model.
    bare_model = model.module if params['parallel'] else model

    writer = SummaryWriter(log_dir=log_dir)
    try:
        writer.add_text('config', str(params))
        t_start = time.time()

        try:
            for epoch in tqdm(range(params['epochs']), disable=params['disable_tqdm'], desc='Epoch'):
                t0 = time.time()
                train_loss = train_epoch(model, optimizer, params, train_loader, Nsamples)
                val_loss = evaluate_epoch(model, params, test_loader)

                writer.add_scalar('train_loss', train_loss, epoch)
                writer.add_scalar('val_loss', val_loss, epoch)
                writer.add_scalar('lr', optimizer.param_groups[0]['lr'], epoch)

                if epoch % 10 == 0:
                    tqdm.write(f"E{epoch} L_train={train_loss:.4f} L_val={val_loss:.4f} LR={optimizer.param_groups[0]['lr']:.4e} EpochTime={(time.time()-t0):.2f}s")

                # Save checkpoint, keeping only the most recent one on disk.
                torch.save(bare_model.state_dict(), os.path.join(ckpt_dir, f"epoch_{epoch}.pth"))
                cleanup_checkpoints(ckpt_dir, keep_latest=1)

                # Scheduler step: ReduceLROnPlateau needs the monitored metric.
                if params['ReduceLROnPlateau'] and scheduler:
                    scheduler.step(val_loss)
                elif params['if_scheduler'] and scheduler:
                    scheduler.step()

                # Early stop once the LR has decayed below the floor.
                if optimizer.param_groups[0]['lr'] < params['min_lr']:
                    print("LR dropped below minimum threshold.")
                    break

                # Wall-clock budget (max_time is in hours).
                if time.time() - t_start > params['max_time'] * 3600:
                    print("Max training time reached, exiting.")
                    break

        except KeyboardInterrupt:
            print("Training interrupted by user.")

        # Final evaluation
        final_train = evaluate_epoch(model, params, train_loader)
        final_val = evaluate_epoch(model, params, test_loader)

        print(f"Done: Train Loss={final_train:.4f}, Val Loss={final_val:.4f}, TotalTime={(time.time()-t_start)/3600:.2f}h")

        # Save config YAML. torch.device is not a plain YAML type, so store its
        # string form to keep the file readable with yaml.safe_load.
        config_path = os.path.join(params['out_dir'], 'configs', phaser + '.yaml')
        os.makedirs(os.path.dirname(config_path), exist_ok=True)
        yaml_params = {**params, 'device': str(params['device'])}
        with open(config_path, 'w') as f:
            yaml.dump({'params': yaml_params, 'inner_channel': c, 'total_parameters': view_model_param(model)}, f)
    finally:
        # Release the TensorBoard writer even if training or evaluation raises.
        writer.close()


def setup_model_device(model, params):
    """Move *model* to ``params['device']``, wrapping it in DataParallel if appropriate.

    Wrapping happens only when ``params['parallel']`` is truthy and more than
    one CUDA device is visible. When wrapped, the inner model's ``loss`` is
    exposed as ``model.loss`` on the wrapper so callers can keep using it.

    Side effects on *params*:
        ``'device'`` is replaced by the corresponding ``torch.device``.
        ``'parallel'`` is set to whether the model was actually wrapped.

    Returns:
        The (possibly wrapped) model on the target device.
    """
    device = torch.device(params['device'])
    use_parallel = bool(params['parallel']) and torch.cuda.device_count() > 1
    if use_parallel:
        model = DataParallel(model).to(device)
        # DataParallel does not forward attribute access to the wrapped module.
        model.loss = model.module.loss
    else:
        # No self-assignment of model.loss here: storing the bound method in the
        # instance __dict__ only created a reference cycle.
        model = model.to(device)
    params['device'] = device
    params['parallel'] = use_parallel
    return model


def setup_scheduler(optimizer, params):
    """Create the learning-rate scheduler described by *params*.

    Returns None when ``params['if_scheduler']`` is falsy. Otherwise returns a
    ``ReduceLROnPlateau`` in 'min' mode when ``params['ReduceLROnPlateau']`` is
    set, else a ``StepLR``.
    """
    if params['if_scheduler']:
        if params['ReduceLROnPlateau']:
            return optim.lr_scheduler.ReduceLROnPlateau(
                optimizer,
                mode='min',
                factor=params['lr_reduce_factor'],
                patience=params['lr_schedule_patience'],
            )
        return optim.lr_scheduler.StepLR(
            optimizer,
            step_size=params['step_size'],
            gamma=params['gamma'],
        )
    return None


def cleanup_checkpoints(directory, keep_latest=2):
    """Delete all but the *keep_latest* highest-numbered ``epoch_<N>.pth`` files.

    Files are ordered by the integer epoch ``N`` parsed from the filename
    rather than by modification time. Coarse or identical mtimes therefore
    cannot cause the newest checkpoint to be removed. ``keep_latest <= 0``
    deletes every numbered checkpoint. Files that match ``epoch_*.pth`` but
    whose suffix is not an integer (e.g. ``epoch_best.pth``) are left alone.
    """
    numbered = []
    for path in glob.glob(os.path.join(directory, 'epoch_*.pth')):
        match = re.fullmatch(r'epoch_(\d+)\.pth', os.path.basename(path))
        if match:
            numbered.append((int(match.group(1)), path))
    numbered.sort()
    # Clamp both ways: files[:-0] would keep everything, and a negative cut
    # would wrongly delete files when fewer than keep_latest exist.
    cut = max(len(numbered) - max(keep_latest, 0), 0)
    for _, path in numbered[:cut]:
        os.remove(path)


if __name__ == "__main__":
    # Number of Enc -> Chan -> Dec repetitions (each with its own optimizer
    # step) per batch in train_epoch. Presumably each pass draws a fresh
    # channel realization; confirm against model2's Chan implementation.
    Nsamples = 100
    main_pipeline(Nsamples=Nsamples)


Ignoring unknown args: ['--f=/run/user/1004/jupyter/runtime/kernel-v3ebb5183588a0404d0c36c4c59164d09354c3f9f5.json']
📡 Training Start
🔧 SNR=11.0, inner channel c=8, ratio=0.17
Found checkpoint: ./out/checkpoints/IMAGENET_8_19.0_0.17_AWGN_10_21h24m09s_on_Jul_18_2025/epoch_502.pth


  model.load_state_dict(torch.load(checkpoint_path, map_location=device))
