In [None]:
# -*- coding: utf-8 -*-
"""
Created on Tue Dec 17 00:00:00 2023
@author: chun (refactored)
"""
import os
import re
import glob
import time
import torch
import yaml
import numpy as np
import torch.nn as nn
import torch.optim as optim
from fractions import Fraction
from torchvision import transforms, datasets
from torch.utils.data import DataLoader
from torch.nn.parallel import DataParallel
from tqdm import tqdm
from tensorboardX import SummaryWriter

from model import DeepJSCC, ratio2filtersize
from utils import image_normalization, set_seed, view_model_param
from dataset import Vanilla


def train_epoch(model, optimizer, param, data_loader):
    """Run one optimisation pass over ``data_loader`` and return the mean batch loss.

    Args:
        model: Network to train. It is called on each image batch and must
            expose ``loss(images, outputs)`` (looked up on ``model.module``
            when ``param['parallel']`` is truthy).
        optimizer: Optimizer whose gradients are zeroed before, and stepped
            after, every batch.
        param: Configuration dict; only ``'device'`` and ``'parallel'`` are
            read here.
        data_loader: Iterable yielding ``(images, label)`` pairs; the label
            is ignored.

    Returns:
        float: Sum of the per-batch losses divided by the number of batches.

    Raises:
        ValueError: If ``data_loader`` yields no batches (previously this
            surfaced as an ``UnboundLocalError`` on the loop variable).
    """
    model.train()
    total_loss = 0.0
    num_batches = 0
    for images, _ in data_loader:
        images = images.to(param['device'])
        optimizer.zero_grad()
        outputs = model(images)
        # Both tensors go through the same 'denormalization' transform before
        # the loss is computed (presumably a rescale to pixel range -- confirm
        # in utils.image_normalization).
        outputs = image_normalization('denormalization')(outputs)
        images = image_normalization('denormalization')(images)
        # Under DataParallel the loss method lives on the wrapped module.
        loss_fn = model.module.loss if param['parallel'] else model.loss
        loss = loss_fn(images, outputs)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        num_batches += 1

    if num_batches == 0:
        raise ValueError("train_epoch received an empty data_loader; cannot average the loss.")
    return total_loss / num_batches


def evaluate_epoch(model, param, data_loader):
    """Compute the mean batch loss of ``model`` over ``data_loader`` without training.

    The model is switched to eval mode and all forward passes run under
    ``torch.no_grad()``.

    Args:
        model: Network to evaluate. It must expose ``loss(images, outputs)``
            (looked up on ``model.module`` when ``param['parallel']`` is
            truthy).
        param: Configuration dict; only ``'device'`` and ``'parallel'`` are
            read here.
        data_loader: Iterable yielding ``(images, label)`` pairs; the label
            is ignored.

    Returns:
        float: Sum of the per-batch losses divided by the number of batches.

    Raises:
        ValueError: If ``data_loader`` yields no batches (previously this
            surfaced as an ``UnboundLocalError`` on the loop variable).
    """
    model.eval()
    total_loss = 0.0
    num_batches = 0
    with torch.no_grad():
        for images, _ in data_loader:
            images = images.to(param['device'])
            outputs = model(images)
            # Same 'denormalization' transform on both tensors as in training,
            # so train and validation losses are on the same scale.
            outputs = image_normalization('denormalization')(outputs)
            images = image_normalization('denormalization')(images)
            # Under DataParallel the loss method lives on the wrapped module.
            loss_fn = model.module.loss if param['parallel'] else model.loss
            total_loss += loss_fn(images, outputs).item()
            num_batches += 1

    if num_batches == 0:
        raise ValueError("evaluate_epoch received an empty data_loader; cannot average the loss.")
    return total_loss / num_batches


def config_parser_pipeline():
    """Parse the command-line options that drive the training sweep.

    Unrecognised arguments (for example the ``--f=...`` kernel file that
    Jupyter injects) are reported on stdout and otherwise ignored.

    Returns:
        argparse.Namespace: Parsed options. ``snr_list`` and ``ratio_list``
        are returned as lists of strings; conversion is left to the caller.
    """
    import argparse

    # (flag, add_argument keyword arguments), registered in this order.
    option_specs = (
        ('--dataset', dict(default='imagenet', type=str,
                           choices=['cifar10', 'imagenet'], help='Dataset name')),
        ('--out', dict(default='./out', type=str,
                       help='Output path for logs and checkpoints')),
        ('--disable_tqdm', dict(action='store_true',
                                help='Disable tqdm progress bars')),
        ('--device', dict(default='cuda:0', type=str,
                          help='Device: cuda:0 / cpu')),
        ('--parallel', dict(action='store_true',
                            help='Use DataParallel if multiple GPUs are available')),
        ('--snr_list', dict(default=['19', '13', '7', '4', '1'], nargs='+',
                            help='List of SNR values (e.g. 5 10 15)')),
        ('--ratio_list', dict(default=['1/6', '1/12'], nargs='+',
                              help='List of channel ratios (e.g. 1/6 1/12)')),
        ('--channel', dict(default='AWGN', type=str,
                           choices=['AWGN', 'Rayleigh'], help='Channel type')),
    )

    cli = argparse.ArgumentParser()
    for flag, spec in option_specs:
        cli.add_argument(flag, **spec)

    known, leftovers = cli.parse_known_args()
    if leftovers:
        print(f"Ignoring unknown args: {leftovers}")
    return known


def main_pipeline():
    """Train one model for every (ratio, SNR) combination given on the command line.

    SNR strings are converted to floats and ratio strings (e.g. ``'1/6'``)
    are parsed as fractions and converted to floats. Ratios form the outer
    loop and SNRs the inner loop; parameters are prepared immediately before
    each training run.
    """
    args = config_parser_pipeline()
    args.snr_list = [float(raw_snr) for raw_snr in args.snr_list]
    args.ratio_list = [float(frac) for frac in map(Fraction, args.ratio_list)]

    print("📡 Training Start")
    sweep = [(ratio, snr) for ratio in args.ratio_list for snr in args.snr_list]
    for ratio, snr in sweep:
        train_pipeline(prepare_params(args, ratio, snr))


def prepare_params(args, ratio, snr):
    """Assemble the parameter dict for one training run and seed the RNGs.

    Args:
        args: Parsed command-line namespace (see ``config_parser_pipeline``).
        ratio: Channel ratio for this run.
        snr: SNR value for this run.

    Returns:
        dict: Run-level settings taken from ``args`` followed by the
        dataset-specific hyper-parameters. Any dataset other than
        ``'cifar10'`` receives the imagenet settings. The device falls back
        to ``'cpu'`` when CUDA is unavailable.
    """
    cuda_ok = torch.cuda.is_available()
    common = {
        'disable_tqdm': args.disable_tqdm,
        'dataset': args.dataset,
        'out_dir': args.out,
        'device': args.device if cuda_ok else 'cpu',
        'parallel': args.parallel,
        'snr': snr,
        'ratio': ratio,
        'channel': args.channel,
    }

    # Key order is kept stable because the dict is later logged via str().
    if args.dataset == 'cifar10':
        hyper = {
            'batch_size': 64,
            'num_workers': 4,
            'epochs': 1000,
            'init_lr': 1e-3,
            'weight_decay': 5e-4,
            'if_scheduler': True,
            'step_size': 640,
            'gamma': 0.1,
            'ReduceLROnPlateau': False,
            'lr_reduce_factor': 0.5,
            'lr_schedule_patience': 15,
            'min_lr': 1e-5,
            'max_time': 12,
            'seed': 42,
        }
    else:  # imagenet
        hyper = {
            'batch_size': 32,
            'num_workers': 4,
            'epochs': 500,
            'init_lr': 1e-4,
            'weight_decay': 5e-4,
            'if_scheduler': True,
            'gamma': 0.1,
            'ReduceLROnPlateau': True,
            'lr_reduce_factor': 0.5,
            'lr_schedule_patience': 15,
            'min_lr': 1e-5,
            'max_time': 12,
            'seed': 42,
        }

    params = {**common, **hyper}
    set_seed(params['seed'])
    return params


def train_pipeline(params):
    """Train and evaluate one DeepJSCC model for a single (ratio, SNR) setting.

    Builds the datasets and loaders, sizes the model's inner channel from the
    first training sample, then trains for up to ``params['epochs']`` epochs
    with per-epoch validation, TensorBoard scalars and rolling checkpoints.
    Training stops early when the learning rate falls below
    ``params['min_lr']``, when ``params['max_time']`` hours have elapsed, or
    on ``KeyboardInterrupt``. Afterwards both splits are re-evaluated and a
    YAML summary is written.

    Args:
        params: Configuration dict as produced by ``prepare_params``. It is
            mutated by ``setup_model_device``, which rewrites ``'device'``
            and ``'parallel'``.

    Side effects:
        Writes under ``params['out_dir']``: TensorBoard logs in ``logs/``,
        the two most recent checkpoints in ``checkpoints/`` and one YAML file
        in ``configs/``. May download CIFAR-10 into ``../dataset/``.
    """
    # Data setup
    if params['dataset'] == 'cifar10':
        transform = transforms.Compose([transforms.ToTensor()])
        train_ds = datasets.CIFAR10(root='../dataset/', train=True, download=True, transform=transform)
        test_ds = datasets.CIFAR10(root='../dataset/', train=False, download=True, transform=transform)
    else:  # imagenet
        # NOTE: despite the option name, this branch reads a site-specific
        # image-folder dataset from an absolute path, resized to 64x64.
        transform_resize = transforms.Compose([transforms.Resize((64, 64)), transforms.ToTensor()])
        train_ds = datasets.ImageFolder('/home/MATLAB_DATA/TiNguyen/Sentry_Data/train', transform=transform_resize)
        test_ds = datasets.ImageFolder('/home/MATLAB_DATA/TiNguyen/Sentry_Data/test', transform=transform_resize)

    train_loader = DataLoader(train_ds, batch_size=params['batch_size'],
                              shuffle=True, num_workers=params['num_workers'])
    test_loader = DataLoader(test_ds, batch_size=params['batch_size'],
                             shuffle=False, num_workers=params['num_workers'])

    # Model init: the inner channel count is derived from one sample image
    # and the requested ratio.
    sample_img, _ = train_ds[0]
    c = ratio2filtersize(sample_img, params['ratio'])
    print(f"🔧 SNR={params['snr']}, inner channel c={c}, ratio={params['ratio']:.2f}")

    model = DeepJSCC(c=c, channel_type=params['channel'], snr=params['snr'])
    model = setup_model_device(model, params)

    # Optimizer and scheduler
    optimizer = optim.Adam(model.parameters(), lr=params['init_lr'], weight_decay=params['weight_decay'])
    scheduler = setup_scheduler(optimizer, params)

    # Logging directories. The run name carries no timestamp, so re-running
    # the same configuration reuses the same directories.
    phaser = f"{params['dataset'].upper()}_{c}_{params['snr']}_{params['ratio']:.2f}_{params['channel']}"
    log_dir = os.path.join(params['out_dir'], 'logs', phaser)
    ckpt_dir = os.path.join(params['out_dir'], 'checkpoints', phaser)
    os.makedirs(ckpt_dir, exist_ok=True)

    writer = SummaryWriter(log_dir=log_dir)
    # try/finally guarantees the event file is closed even if training or
    # the final evaluation raises something other than KeyboardInterrupt.
    try:
        writer.add_text('config', str(params))

        t_start = time.time()

        try:
            for epoch in tqdm(range(params['epochs']), disable=params['disable_tqdm'], desc='Epoch'):
                t0 = time.time()
                train_loss = train_epoch(model, optimizer, params, train_loader)
                val_loss = evaluate_epoch(model, params, test_loader)

                writer.add_scalar('train_loss', train_loss, epoch)
                writer.add_scalar('val_loss', val_loss, epoch)
                writer.add_scalar('lr', optimizer.param_groups[0]['lr'], epoch)

                if epoch % 100 == 0:
                    tqdm.write(f"E{epoch} L_train={train_loss:.4f} L_val={val_loss:.4f} LR={optimizer.param_groups[0]['lr']:.4e} EpochTime={(time.time()-t0):.2f}s")

                # Save checkpoint, keeping only the two most recent files.
                torch.save(model.state_dict(), os.path.join(ckpt_dir, f"epoch_{epoch}.pth"))
                cleanup_checkpoints(ckpt_dir, keep_latest=2)

                # Scheduler step: the plateau scheduler needs the monitored
                # metric, the step scheduler does not.
                if params['ReduceLROnPlateau'] and scheduler:
                    scheduler.step(val_loss)
                elif params['if_scheduler'] and scheduler:
                    scheduler.step()

                # Early stop trigger
                if optimizer.param_groups[0]['lr'] < params['min_lr']:
                    print("LR dropped below minimum threshold.")
                    break

                # Time check ('max_time' is expressed in hours).
                if time.time() - t_start > params['max_time'] * 3600:
                    print("Max training time reached, exiting.")
                    break

        except KeyboardInterrupt:
            print("Training interrupted by user.")

        # Final evaluation
        final_train = evaluate_epoch(model, params, train_loader)
        final_val = evaluate_epoch(model, params, test_loader)

        print(f"Done: Train Loss={final_train:.4f}, Val Loss={final_val:.4f}, TotalTime={(time.time()-t_start)/3600:.2f}h")

        # Save config YAML. setup_model_device replaced params['device'] with
        # a torch.device object; dump it as a plain string so the file stays
        # readable by safe YAML loaders.
        config_path = os.path.join(params['out_dir'], 'configs', phaser + '.yaml')
        os.makedirs(os.path.dirname(config_path), exist_ok=True)
        serializable_params = dict(params, device=str(params['device']))
        with open(config_path, 'w') as f:
            yaml.dump({'params': serializable_params, 'inner_channel': c,
                       'total_parameters': view_model_param(model)}, f)
    finally:
        writer.close()


def setup_model_device(model, params):
    """Move ``model`` to the configured device, wrapping it in DataParallel if applicable.

    DataParallel is used only when ``params['parallel']`` is truthy and more
    than one CUDA device is visible.

    Args:
        model: Module exposing a ``loss`` attribute.
        params: Configuration dict. On return ``params['device']`` holds a
            ``torch.device`` and ``params['parallel']`` reflects whether
            DataParallel is actually in use.

    Returns:
        The model on the target device, with ``loss`` reachable directly on
        the returned object.
    """
    target = torch.device(params['device'])
    multi_gpu = params['parallel'] and torch.cuda.device_count() > 1

    if multi_gpu:
        model = DataParallel(model)
    model = model.to(target)

    if isinstance(model, DataParallel):
        # Expose the wrapped module's loss on the wrapper itself.
        model.loss = model.module.loss
    else:
        # Re-binds the existing method as an instance attribute.
        model.loss = model.loss

    params['device'] = target
    params['parallel'] = multi_gpu
    return model


def setup_scheduler(optimizer, params):
    """Create the learning-rate scheduler selected by ``params``.

    Args:
        optimizer: Optimizer the scheduler will drive.
        params: Configuration dict. ``'if_scheduler'`` enables scheduling;
            ``'ReduceLROnPlateau'`` chooses between a plateau scheduler
            (using ``'lr_reduce_factor'`` and ``'lr_schedule_patience'``)
            and a step scheduler (using ``'step_size'`` and ``'gamma'``).

    Returns:
        The scheduler instance, or ``None`` when scheduling is disabled.
    """
    if params['if_scheduler']:
        schedulers = optim.lr_scheduler
        if params['ReduceLROnPlateau']:
            chosen = schedulers.ReduceLROnPlateau(
                optimizer,
                'min',
                factor=params['lr_reduce_factor'],
                patience=params['lr_schedule_patience'],
            )
        else:
            chosen = schedulers.StepLR(
                optimizer,
                step_size=params['step_size'],
                gamma=params['gamma'],
            )
        return chosen
    return None


def cleanup_checkpoints(directory, keep_latest=2):
    """Delete all but the ``keep_latest`` most recently modified checkpoints.

    Only files matching ``epoch_*.pth`` directly inside ``directory`` are
    considered; they are ranked by modification time, oldest first.

    Args:
        directory: Folder containing the checkpoint files.
        keep_latest: Number of newest checkpoints to keep. ``0`` removes
            every checkpoint.

    Raises:
        ValueError: If ``keep_latest`` is negative.
    """
    if keep_latest < 0:
        raise ValueError(f"keep_latest must be non-negative, got {keep_latest}")

    checkpoints = sorted(glob.glob(os.path.join(directory, 'epoch_*.pth')), key=os.path.getmtime)
    # Slice by count from the front: the former ``files[:-keep_latest]`` is
    # an empty slice when keep_latest == 0 and so deleted nothing.
    num_stale = max(len(checkpoints) - keep_latest, 0)
    for path in checkpoints[:num_stale]:
        os.remove(path)


# Entry point: run the full ratio x SNR training sweep when this file is
# executed directly (not when it is imported).
if __name__ == "__main__":
    main_pipeline()


Ignoring unknown args: ['--f=/run/user/1004/jupyter/runtime/kernel-v3811fe931361453f97682a92081d5b7ada06276c7.json']
📡 Training Start
🔧 SNR=19.0, inner channel c=8, ratio=0.17


Epoch:   0%|          | 1/500 [00:05<41:44,  5.02s/it]

E0 L_train=545.6990 L_val=259.4189 LR=1.0000e-04 EpochTime=5.01s


Epoch:  20%|██        | 101/500 [08:36<36:18,  5.46s/it]

E100 L_train=36.9337 L_val=34.4725 LR=1.0000e-04 EpochTime=5.35s


Epoch:  40%|████      | 201/500 [15:25<16:51,  3.38s/it]

E200 L_train=24.2392 L_val=22.6311 LR=1.0000e-04 EpochTime=3.36s


Epoch:  60%|██████    | 301/500 [23:30<16:57,  5.11s/it]

E300 L_train=18.5636 L_val=17.2912 LR=1.0000e-04 EpochTime=5.23s


Epoch:  80%|████████  | 401/500 [30:26<05:33,  3.37s/it]

E400 L_train=15.1809 L_val=14.6833 LR=1.0000e-04 EpochTime=3.34s


Epoch: 100%|██████████| 500/500 [35:59<00:00,  4.32s/it]


Done: Train Loss=12.8711, Val Loss=12.4212, TotalTime=0.60h
🔧 SNR=13.0, inner channel c=8, ratio=0.17


Epoch:   0%|          | 1/500 [00:03<28:48,  3.46s/it]

E0 L_train=510.3064 L_val=241.5052 LR=1.0000e-04 EpochTime=3.46s


Epoch:  20%|██        | 101/500 [05:58<21:22,  3.21s/it]

E100 L_train=39.5009 L_val=36.4590 LR=1.0000e-04 EpochTime=3.17s


Epoch:  40%|████      | 201/500 [13:47<22:17,  4.47s/it]

E200 L_train=26.9483 L_val=25.5111 LR=1.0000e-04 EpochTime=5.07s


Epoch:  60%|██████    | 301/500 [21:41<17:16,  5.21s/it]

E300 L_train=22.2949 L_val=21.0286 LR=1.0000e-04 EpochTime=5.28s


Epoch:  80%|████████  | 401/500 [29:08<08:42,  5.28s/it]

E400 L_train=19.5643 L_val=18.7746 LR=1.0000e-04 EpochTime=5.31s


Epoch: 100%|██████████| 500/500 [37:26<00:00,  4.49s/it]


Done: Train Loss=18.6968, Val Loss=18.0622, TotalTime=0.62h
🔧 SNR=7.0, inner channel c=8, ratio=0.17


Epoch:   0%|          | 1/500 [00:03<28:59,  3.49s/it]

E0 L_train=506.9062 L_val=243.5118 LR=1.0000e-04 EpochTime=3.48s


Epoch:  20%|██        | 101/500 [07:10<35:31,  5.34s/it]

E100 L_train=45.7241 L_val=42.4450 LR=1.0000e-04 EpochTime=5.31s


Epoch:  40%|████      | 201/500 [14:27<15:32,  3.12s/it]

E200 L_train=35.1733 L_val=32.3941 LR=1.0000e-04 EpochTime=3.12s


Epoch:  60%|██████    | 301/500 [19:49<11:25,  3.45s/it]

E300 L_train=29.7133 L_val=29.0368 LR=1.0000e-04 EpochTime=3.52s


Epoch:  80%|████████  | 401/500 [25:21<05:30,  3.34s/it]

E400 L_train=26.3185 L_val=25.1916 LR=5.0000e-05 EpochTime=3.30s


Epoch: 100%|██████████| 500/500 [30:44<00:00,  3.69s/it]


Done: Train Loss=24.7648, Val Loss=23.7896, TotalTime=0.51h
🔧 SNR=4.0, inner channel c=8, ratio=0.17


Epoch:   0%|          | 1/500 [00:03<26:16,  3.16s/it]

E0 L_train=514.7751 L_val=252.0151 LR=1.0000e-04 EpochTime=3.16s


Epoch:  20%|██        | 101/500 [05:16<20:44,  3.12s/it]

E100 L_train=54.7512 L_val=50.9986 LR=1.0000e-04 EpochTime=3.24s


Epoch:  40%|████      | 201/500 [10:29<15:34,  3.13s/it]

E200 L_train=43.8230 L_val=41.1080 LR=1.0000e-04 EpochTime=3.07s


Epoch:  60%|██████    | 301/500 [15:43<10:18,  3.11s/it]

E300 L_train=39.1250 L_val=36.8404 LR=1.0000e-04 EpochTime=3.05s


Epoch:  80%|████████  | 401/500 [20:59<05:13,  3.16s/it]

E400 L_train=35.9653 L_val=34.1188 LR=1.0000e-04 EpochTime=3.08s


Epoch: 100%|██████████| 500/500 [25:58<00:00,  3.12s/it]


Done: Train Loss=33.3778, Val Loss=31.9655, TotalTime=0.43h
🔧 SNR=1.0, inner channel c=8, ratio=0.17


Epoch:   0%|          | 1/500 [00:03<25:38,  3.08s/it]

E0 L_train=533.8811 L_val=273.1242 LR=1.0000e-04 EpochTime=3.08s


Epoch:  20%|██        | 101/500 [05:17<20:55,  3.15s/it]

E100 L_train=67.1327 L_val=62.9045 LR=1.0000e-04 EpochTime=3.14s


Epoch:  40%|████      | 201/500 [10:34<15:37,  3.14s/it]

E200 L_train=56.7960 L_val=56.6744 LR=1.0000e-04 EpochTime=3.09s


Epoch:  60%|██████    | 301/500 [15:49<10:23,  3.13s/it]

E300 L_train=51.6576 L_val=48.8868 LR=1.0000e-04 EpochTime=3.16s


Epoch:  80%|████████  | 401/500 [20:51<04:52,  2.96s/it]

E400 L_train=48.6044 L_val=46.8629 LR=1.0000e-04 EpochTime=2.95s


Epoch: 100%|██████████| 500/500 [25:51<00:00,  3.10s/it]


Done: Train Loss=45.7772, Val Loss=43.7401, TotalTime=0.43h
🔧 SNR=19.0, inner channel c=4, ratio=0.08


Epoch:   0%|          | 1/500 [00:03<25:26,  3.06s/it]

E0 L_train=516.3085 L_val=260.4738 LR=1.0000e-04 EpochTime=3.05s


Epoch:  20%|██        | 101/500 [05:01<19:39,  2.96s/it]

E100 L_train=41.9954 L_val=39.5714 LR=1.0000e-04 EpochTime=2.92s


Epoch:  40%|████      | 201/500 [10:05<15:46,  3.16s/it]

E200 L_train=29.4549 L_val=27.8564 LR=1.0000e-04 EpochTime=2.99s


Epoch:  60%|██████    | 301/500 [15:09<09:52,  2.98s/it]

E300 L_train=23.9149 L_val=22.5787 LR=1.0000e-04 EpochTime=2.91s


Epoch:  80%|████████  | 401/500 [20:16<05:05,  3.09s/it]

E400 L_train=21.2095 L_val=20.0221 LR=1.0000e-04 EpochTime=3.18s


Epoch: 100%|██████████| 500/500 [25:24<00:00,  3.05s/it]


Done: Train Loss=19.6870, Val Loss=19.0149, TotalTime=0.42h
🔧 SNR=13.0, inner channel c=4, ratio=0.08


Epoch:   0%|          | 1/500 [00:02<24:07,  2.90s/it]

E0 L_train=502.4230 L_val=258.1241 LR=1.0000e-04 EpochTime=2.89s


Epoch:  20%|██        | 101/500 [05:14<19:55,  3.00s/it]

E100 L_train=44.0586 L_val=45.7035 LR=1.0000e-04 EpochTime=2.98s


Epoch:  40%|████      | 201/500 [10:30<14:59,  3.01s/it]

E200 L_train=30.2612 L_val=28.6291 LR=1.0000e-04 EpochTime=2.99s


Epoch:  60%|██████    | 301/500 [15:46<10:24,  3.14s/it]

E300 L_train=25.8017 L_val=26.8183 LR=1.0000e-04 EpochTime=3.09s


Epoch:  80%|████████  | 401/500 [21:04<05:11,  3.14s/it]

E400 L_train=23.9164 L_val=23.2139 LR=1.0000e-04 EpochTime=3.08s


Epoch: 100%|██████████| 500/500 [26:16<00:00,  3.15s/it]


Done: Train Loss=22.5004, Val Loss=21.6957, TotalTime=0.44h
🔧 SNR=7.0, inner channel c=4, ratio=0.08


Epoch:   0%|          | 1/500 [00:03<27:14,  3.27s/it]

E0 L_train=507.4045 L_val=265.3708 LR=1.0000e-04 EpochTime=3.27s


Epoch:  20%|██        | 101/500 [05:22<21:07,  3.18s/it]

E100 L_train=54.3320 L_val=51.0412 LR=1.0000e-04 EpochTime=3.02s


Epoch:  40%|████      | 201/500 [10:33<15:36,  3.13s/it]

E200 L_train=42.8616 L_val=42.0990 LR=1.0000e-04 EpochTime=3.14s


Epoch:  60%|██████    | 301/500 [15:33<09:47,  2.95s/it]

E300 L_train=38.4170 L_val=38.2851 LR=1.0000e-04 EpochTime=2.95s


Epoch:  80%|████████  | 401/500 [20:30<04:50,  2.93s/it]

E400 L_train=36.4710 L_val=35.3556 LR=1.0000e-04 EpochTime=2.88s


Epoch: 100%|██████████| 500/500 [25:23<00:00,  3.05s/it]


Done: Train Loss=34.7790, Val Loss=33.3560, TotalTime=0.42h
🔧 SNR=4.0, inner channel c=4, ratio=0.08


Epoch:   0%|          | 1/500 [00:03<25:17,  3.04s/it]

E0 L_train=518.3998 L_val=274.7400 LR=1.0000e-04 EpochTime=3.04s


Epoch:  20%|██        | 101/500 [05:00<19:54,  2.99s/it]

E100 L_train=66.9075 L_val=63.2098 LR=1.0000e-04 EpochTime=3.03s


Epoch:  40%|████      | 201/500 [09:56<15:01,  3.01s/it]

E200 L_train=56.8489 L_val=54.2843 LR=1.0000e-04 EpochTime=3.01s


Epoch:  60%|██████    | 301/500 [14:53<09:45,  2.94s/it]

E300 L_train=52.3274 L_val=51.8557 LR=1.0000e-04 EpochTime=2.88s


Epoch:  80%|████████  | 401/500 [19:50<04:54,  2.97s/it]

E400 L_train=49.8182 L_val=47.8092 LR=1.0000e-04 EpochTime=2.93s


Epoch: 100%|██████████| 500/500 [24:43<00:00,  2.97s/it]


Done: Train Loss=48.1030, Val Loss=45.9950, TotalTime=0.41h
🔧 SNR=1.0, inner channel c=4, ratio=0.08


Epoch:   0%|          | 1/500 [00:02<24:50,  2.99s/it]

E0 L_train=533.6301 L_val=286.0655 LR=1.0000e-04 EpochTime=2.98s


Epoch:  20%|██        | 101/500 [04:58<19:40,  2.96s/it]

E100 L_train=85.2541 L_val=82.2455 LR=1.0000e-04 EpochTime=2.97s


Epoch:  40%|████      | 201/500 [09:56<14:57,  3.00s/it]

E200 L_train=76.0089 L_val=72.0585 LR=1.0000e-04 EpochTime=3.00s


Epoch:  60%|██████    | 301/500 [14:52<09:45,  2.94s/it]

E300 L_train=71.2931 L_val=68.2564 LR=1.0000e-04 EpochTime=2.91s


Epoch:  80%|████████  | 401/500 [19:50<04:55,  2.99s/it]

E400 L_train=68.2096 L_val=64.4565 LR=5.0000e-05 EpochTime=3.04s


Epoch: 100%|██████████| 500/500 [24:43<00:00,  2.97s/it]


Done: Train Loss=66.4134, Val Loss=63.2677, TotalTime=0.41h
