In [None]:
%pip install einops



In [1]:

import os
import torch
import torch.nn.functional as F
import copy
import random
import numpy as np
import math
import csv
import warnings
import tqdm
from torch.utils.data import DataLoader
from model.kbynet import YOLO
from model.metrics import EMA, AverageMeter, ComputeLoss, psnr,ssim
from model.dataset import YOLODataset, collateFunction
from model.utils import clip_gradients,strip_optimizer,scale,box_iou,compute_ap,non_max_suppression


In [13]:


def setup_seed(seed=0):
    """
    Seed Python, NumPy and PyTorch RNGs and force deterministic cuDNN.

    Args:
        seed: value passed to every RNG. Defaults to 0, the value that was
            previously hard-coded, so ``setup_seed()`` behaves as before.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Disable cuDNN autotuning and request deterministic kernels so repeated
    # runs select the same convolution algorithms.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True


def setup_multi_processes():
    """
    Configure process start method and per-process thread counts.

    * On every platform except Windows, torch multiprocessing is forced to
      the ``fork`` start method.
    * OpenCV's internal thread pool is disabled.
    * ``OMP_NUM_THREADS`` and ``MKL_NUM_THREADS`` default to ``'1'`` unless
      the environment already defines them.
    """
    import cv2
    from os import environ
    from platform import system

    # `fork` is used to speed up training (cheaper worker start-up).
    if system() != 'Windows':
        torch.multiprocessing.set_start_method('fork', force=True)

    # Keep OpenCV from spawning its own threads inside data-loading workers.
    cv2.setNumThreads(0)

    # Cap OpenMP / MKL threads, but respect values set by the user.
    for thread_var in ('OMP_NUM_THREADS', 'MKL_NUM_THREADS'):
        environ.setdefault(thread_var, '1')




In [14]:


warnings.filterwarnings("ignore")

def custom_exponential_lr(gamma):
    """Return a LambdaLR-style factor function ``epoch -> gamma ** epoch``."""
    def _decay_factor(epoch):
        return gamma ** epoch
    return _decay_factor

def one_cycle(y1=0.0, y2=1.0, steps=100):
    """
    Return a function for a half-cosine ramp from ``y1`` (at x=0) to ``y2`` (at x=steps).

    See https://arxiv.org/pdf/1812.01187.pdf.
    """
    span = y2 - y1

    def _ramp(x):
        # Fraction of the way along the ramp, in [0, 1] for 0 <= x <= steps.
        progress = (1 - math.cos(x * math.pi / steps)) / 2
        return max(progress, 0) * span + y1

    return _ramp

def learning_rate(args):
    """
    Return a linear-decay factor function over ``args['epochs']``.

    The factor is 1.0 at epoch 0 and ``args['lrf']`` at the final epoch.
    ``args`` is read on every call, not captured at creation time.
    """
    def fn(x):
        final_factor = args['lrf']
        remaining = 1 - x / args['epochs']
        return remaining * (1.0 - final_factor) + final_factor

    return fn


def train(args):
    """
    Train the joint detection + restoration YOLO model.

    If ``checkpoint_path`` exists, resume from it: model, optimizer, AMP
    scaler, EMA (weights and update count), scheduler, epoch counter and best
    mAP@50. Otherwise initialise from a pretrained weight file, copying only
    tensors whose name exists in this model with an identical shape.

    Logging and checkpoints:
    * Per-epoch train/val metrics are written to ``step.csv``. The file is
      appended to when resuming and it already exists.
    * ``last.pt`` is saved every epoch.
    * ``best.pt`` is saved whenever validation mAP@50 reaches the best so far.

    Args:
        args: configuration dict (see ``main``). ``args['weight_decay']`` is
            scaled in place by the gradient-accumulation factor.
    """
    # ---- Model and (optional) resume state ----
    checkpoint_path = '/kaggle/input/vocfogalst/none.pt'
    optimizer_state_dict = None
    amp_scale_state_dict = None
    start_epoch = 0
    best = 0
    ema_state_dict = None
    ema_updates = 0
    scheduler_state_dict = None
    depth = [1, 2, 2]
    width = [3, 32, 64, 128, 256, 512]
    model = YOLO(width, depth, args['nc']).cuda()
    weight_opt = None
    if os.path.exists(checkpoint_path):
        # NOTE(review): torch.load unpickles arbitrary objects; only load trusted checkpoints.
        ckpt = torch.load(checkpoint_path)
        model.load_state_dict(ckpt['model'].state_dict())
        optimizer_state_dict = ckpt['optimizer']
        amp_scale_state_dict = ckpt['amp_scale']
        start_epoch = ckpt['epoch'] + 1
        best = ckpt['best']
        ema_state_dict = ckpt['ema'].state_dict()
        ema_updates = ckpt['ema_updates']
        scheduler_state_dict = ckpt['scheduler']
        print(f'Resuming training from epoch {start_epoch}')
    else:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        model_dict = model.state_dict()
        pretrained_dict = torch.load("/kaggle/input/utilis2/yolov8_s.pt", map_location=device)
        # Keep only tensors that exist in this model with the same shape.
        # Keys the model does not have, or whose shape differs (presumably the
        # head when the class count changes), keep their fresh initialisation.
        pretrained_dict = {k: v for k, v in pretrained_dict.items()
                           if k in model_dict and np.shape(model_dict[k]) == np.shape(v)}
        model_dict.update(pretrained_dict)
        model.load_state_dict(model_dict)

    ema = EMA(model)
    if ema_state_dict is not None:
        print('Load ema')
        ema.ema.load_state_dict(ema_state_dict)
        ema.updates = ema_updates

    # Accumulate gradients toward a nominal batch of 64 and scale weight decay
    # to match (mutates args in place).
    accumulate = max(round(64 / (args['batch_size'])), 1)
    args['weight_decay'] *= args['batch_size'] * accumulate / 64
    print(args['weight_decay'])
    model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
    model = torch.nn.DataParallel(model)

    UnFreeze_flag = False

    # When Freeze_Train is set, the first 69 trainable parameters are frozen
    # until epoch Freeze_Epoch.
    param_names = [name for name, param in model.named_parameters() if param.requires_grad]
    layers_to_freeze = param_names[:69]
    if args['Freeze_Train']:
        print('Freezing Layer...')
        for k, v in model.named_parameters():
            if k in layers_to_freeze:
                v.requires_grad = False
                print(k)

    def get_param_group(i, k):
        """Map parameter index ``i`` to its learning-rate group name."""
        if i <= 98:  # backbone layers
            return 'backbone'
        elif 99 <= i <= 374:  # 'res' group (parameter indices 99-374)
            return 'res'
        else:
            return 'det'

    # Parameters split by group and by type; biases and norm weights get no weight decay.
    params = {'backbone': {'bias': [], 'norm': [], 'others': []},
              'res': {'bias': [], 'norm': [], 'others': []},
              'det': {'bias': [], 'norm': [], 'others': []}}

    for i, (k, v) in enumerate(model.named_parameters()):
        group = get_param_group(i, k)
        if 'bias' in k or 'b_custom' in k:
            params[group]['bias'].append(v)
            param_type = 'bias'
        elif 'norm' in k:
            params[group]['norm'].append(v)
            param_type = 'norm'
        else:
            params[group]['others'].append(v)
            param_type = 'others'
        print(f"Layer {i}: {k}, Group: {group}, Type: {param_type}")  # print layer name, group, and type

    # Param-group order matters: the warmup below treats groups 0, 3 and 6
    # (the bias groups) specially.
    optimizer = torch.optim.SGD(params['backbone']['bias'], lr=args['lr0'], momentum=args['momentum'], weight_decay=0)

    optimizer.add_param_group({'params': params['backbone']['norm'], 'lr': args['lr0'],'weight_decay': 0})
    optimizer.add_param_group({'params': params['backbone']['others'], 'lr': args['lr0'], 'weight_decay': args['weight_decay']})
    optimizer.add_param_group({'params': params['res']['bias'], 'lr': args['lr1'], 'weight_decay': 0})
    optimizer.add_param_group({'params': params['res']['norm'], 'lr': args['lr1'],'weight_decay': 0})
    optimizer.add_param_group({'params': params['res']['others'],'lr': args['lr1'], 'weight_decay': args['weight_decay']})
    optimizer.add_param_group({'params': params['det']['bias'], 'lr': args['lr2'], 'weight_decay': 0})
    optimizer.add_param_group({'params': params['det']['norm'],'lr': args['lr2'], 'weight_decay': 0})
    optimizer.add_param_group({'params': params['det']['others'], 'lr': args['lr2'], 'weight_decay': args['weight_decay']})

    if optimizer_state_dict is not None:
        print('Load optimizier')
        optimizer.load_state_dict(optimizer_state_dict)

    del params

    lr0 = one_cycle(1, args['lrf'], args['cosine_epochs'])
    lr1 = one_cycle(1, args['lrf'], args['cosine_epochs'])
    lr2 = one_cycle(1, args['lrf'], args['cosine_epochs'])

    train_dir = '/kaggle/input/vocfog/vocfog/train'

    dataset = YOLODataset(train_dir, args['input_size'], args, True)
    loader = DataLoader(dataset, batch_size=args['batch_size'], num_workers=4,
                        shuffle=True, collate_fn=collateFunction)
    # One lambda per param group: 3 backbone, 3 res, 3 det.
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=[lr0]*3+[lr1]*3+[lr2]*3, last_epoch=start_epoch-1)

    if scheduler_state_dict is not None:
        print('Load scheduler')
        scheduler.load_state_dict(scheduler_state_dict)
        print(scheduler_state_dict)

    # Start training
    num_batch = len(loader)
    amp_scale = torch.cuda.amp.GradScaler()
    if amp_scale_state_dict is not None:
        print('Load scaler')
        amp_scale.load_state_dict(amp_scale_state_dict)

    criterion = ComputeLoss(model, args)

    num_warmup = max(round(args['warmup_epochs'] * num_batch), 1000)

    log_path = 'step.csv'
    fieldnames = ['epoch', 'mAP@50', 'mAP','PSNR','SSIM', 'total_loss','det_loss','res_loss','val_loss','val_det','val_res']
    # On resume, append to the existing log instead of truncating its history.
    resume_log = start_epoch > 0 and os.path.exists(log_path)
    with open(log_path, 'a' if resume_log else 'w', newline='') as f:
        writer = csv.DictWriter(f, fieldnames=fieldnames)
        if not resume_log:
            writer.writeheader()
        for epoch in range(start_epoch, args['epochs']):
            model.train()
            if epoch >= args['Freeze_Epoch'] and not UnFreeze_flag and args['Freeze_Train']:
                print('unfreeze')
                for k, v in model.named_parameters():
                    if k in layers_to_freeze:
                        v.requires_grad = True
                UnFreeze_flag = True
            m_loss = AverageMeter()
            r_loss = AverageMeter()
            d_loss = AverageMeter()
            # Header columns match the six fields of each progress-bar row below.
            print(('\n' + '%10s' * 6) % ('epoch', 'memory', 'loss', 'det', 'res', 'lr'))
            p_bar = tqdm.tqdm(enumerate(loader), total=num_batch)  # progress bar

            optimizer.zero_grad()

            for i, (samples, norain, targets, shapes) in p_bar:
                x = i + num_batch * epoch  # global iteration index
                samples = samples.float().cuda()
                norain = norain.float().cuda()
                targets = targets.cuda()

                # Warmup: ramp accumulation, per-group LR (bias groups start at
                # warmup_bias_lr, the rest at 0) and momentum linearly over the
                # first num_warmup iterations.
                if x <= num_warmup:
                    xp = [0, num_warmup]
                    fp = [1, 64 / (args['batch_size'] * args['world_size'])]
                    accumulate = max(1, np.interp(x, xp, fp).round())
                    for j, y in enumerate(optimizer.param_groups):
                        if j == 0 or j == 3 or j == 6:
                            fp = [args['warmup_bias_lr'], y['initial_lr'] * lr1(epoch)]
                        else:
                            fp = [0.0, y['initial_lr'] * lr1(epoch)]
                        y['lr'] = np.interp(x, xp, fp)
                        if 'momentum' in y:
                            fp = [args['warmup_momentum'], args['momentum']]
                            y['momentum'] = np.interp(x, xp, fp)

                # Forward: weighted sum of detection loss and L1 restoration loss.
                with torch.cuda.amp.autocast(enabled=True):
                    outputs = model(samples)
                    det_loss = criterion(outputs['Detection'], targets)
                    res_loss = F.l1_loss(outputs['Restoration'], norain)
                    loss = (args['det'] * det_loss) + (args['res'] * res_loss)

                m_loss.update(loss.item(), samples.size(0))
                d_loss.update(det_loss.item(), samples.size(0))
                r_loss.update(res_loss.item(), samples.size(0))

                amp_scale.scale(loss).backward()

                # Optimizer step every `accumulate` iterations.
                if x % accumulate == 0:
                    amp_scale.unscale_(optimizer)  # unscale gradients
                    clip_gradients(model)  # clip gradients
                    amp_scale.step(optimizer)  # optimizer.step
                    amp_scale.update()
                    optimizer.zero_grad()
                    if ema:
                        ema.update(model)

                memory = f'{torch.cuda.memory_reserved() / 1E9:.3g}G'  # (GB)
                s = ('%10s' * 2 + '%10.4g' * 4) % (f"{epoch + 1}/{args['epochs']}", memory, m_loss.avg, d_loss.avg, r_loss.avg,
                                                   scheduler.optimizer.param_groups[0]['lr'])
                p_bar.set_description(s)
                del loss
                del det_loss
                del res_loss
                del outputs

            scheduler.step()
            if args['local_rank'] == 0:
                # Validate the EMA weights.
                last = test(args, ema.ema, criterion, weight_opt)
                writer.writerow({'mAP': str(f'{last[1]:.3f}'),
                                 'epoch': str(epoch + 1).zfill(3),
                                 'mAP@50': str(f'{last[0]:.3f}'),
                                 'PSNR': str(f'{last[2]:.3f}'),
                                 'SSIM': str(f'{last[3]:.3f}'),
                                 'total_loss': str(f'{m_loss.avg:.3f}'),
                                 'det_loss': str(f'{d_loss.avg:.3f}'),
                                 'res_loss': str(f'{r_loss.avg:.3f}'),
                                 'val_loss': str(f'{last[4]:.3f}'),
                                 'val_det': str(f'{last[5]:.3f}'),
                                 'val_res': str(f'{last[6]:.3f}'),
                                 })
                f.flush()

                # Update best mAP@50
                if last[0] > best:
                    best = last[0]

                # Save model
                ckpt = {
                    'model': copy.deepcopy(model.module).half(),
                    'optimizer': optimizer.state_dict(),
                    'amp_scale': amp_scale.state_dict(),
                    'epoch': epoch,
                    'best': best,
                    'ema': copy.deepcopy(ema.ema).half(),  # Save EMA state
                    'ema_updates': ema.updates,  # Save updates
                    'scheduler': scheduler.state_dict(),  # Save scheduler state
                }
                # Save last, best and delete
                torch.save(ckpt, 'last.pt')
                if best == last[0]:
                    torch.save(ckpt, 'best.pt')
                del ckpt

    if args['local_rank'] == 0:
        # best.pt may never be written in this run (e.g. resumed without
        # improvement, or no epochs left), so only strip files that exist.
        for weight_path in ('best.pt', 'last.pt'):
            if os.path.exists(weight_path):
                strip_optimizer(weight_path)  # strip optimizers

    torch.cuda.empty_cache()



@torch.no_grad()
def test(args, model=None,criterion=None,weight_opt=None):
    """
    Evaluate detection (mAP) and restoration (PSNR / SSIM) on the validation split.

    Args:
        args: configuration dict. Reads input_size, batch_size, det and res,
            plus whatever YOLODataset / ComputeLoss read.
        model: model to evaluate. If None, the ``'ema'`` entry of a fixed
            checkpoint path is loaded. The model is converted to half precision
            for evaluation and back to float before returning.
        criterion: detection loss. Built from ``model`` when None.
        weight_opt: unused; kept for call compatibility.

    Returns:
        Tuple ``(mAP@50, mAP@0.5:0.95, mean PSNR, mean SSIM, mean total loss,
        mean detection loss, mean restoration loss)``.
    """
    val_dir = '/kaggle/input/vocfog/vocfog/test'
    val_data = YOLODataset(val_dir, args['input_size'], args, False)

    # Evaluation does not need shuffling; a fixed order keeps runs reproducible.
    val_loader = DataLoader(val_data, batch_size=args['batch_size'], num_workers=4,
                            shuffle=False, collate_fn=collateFunction)

    if model is None:
        model = torch.load('/kaggle/input/weights/best (1).pt', map_location='cuda')['ema'].float()

    if criterion is None:
        criterion = ComputeLoss(model, args)

    model.half()
    model.eval()

    # Configure
    iou_v = torch.linspace(0.5, 0.95, 10).cuda()  # iou vector for mAP@0.5:0.95
    n_iou = iou_v.numel()
    psnr_total = 0.
    ssim_total = 0.
    num_samples = 0
    m_pre = 0.
    m_rec = 0.
    map50 = 0.
    mean_ap = 0.
    metrics = []
    m_loss = AverageMeter()
    r_loss = AverageMeter()
    d_loss = AverageMeter()

    p_bar = tqdm.tqdm(val_loader, desc=('%10s' * 3) % ('precision', 'recall', 'mAP'))
    for samples, norain, targets, shapes in p_bar:
        samples = samples.cuda()
        targets = targets.cuda()
        norain = norain.cuda()
        samples = samples.half()  # match the half-precision model
        _, _, height, width = samples.shape  # batch size, channels, height, width

        # Inference
        outputs = model(samples)
        det_loss = criterion(outputs['Detection'][0], targets)
        res_loss = F.l1_loss(outputs['Restoration'], norain)
        loss = (args['det'] * det_loss) + (args['res'] * res_loss)

        m_loss.update(loss.item(), samples.size(0))
        d_loss.update(det_loss.item(), samples.size(0))
        r_loss.update(res_loss.item(), samples.size(0))

        restoration = outputs['Restoration']

        for i in range(samples.shape[0]):
            # Crop letterbox padding before scoring. shapes[i][1] is the same
            # ratio_pad handed to `scale` for boxes below; assuming the YOLO
            # convention, its pad entry is (pad_x, pad_y), so unpack width first.
            pad_w, pad_h = shapes[i][1][1]
            pad_w, pad_h = int(pad_w), int(pad_h)

            restored_img = restoration[i, :, pad_h:height-pad_h, pad_w:width-pad_w]
            norain_img = norain[i, :, pad_h:height-pad_h, pad_w:width-pad_w]

            # Scale [0, 1] images to uint8 before computing the metrics.
            restored_img = torch.clamp(restored_img.mul(255), 0, 255).byte()
            norain_img = torch.clamp(norain_img.mul(255), 0, 255).byte()
            restored_img = restored_img.unsqueeze(0)
            norain_img = norain_img.unsqueeze(0)
            psnr_total += psnr(norain_img, restored_img)
            ssim_total += ssim(norain_img, restored_img)
            num_samples += 1

        targets[:, 2:] *= torch.tensor((width, height, width, height)).cuda()  # to pixels
        det_outputs = non_max_suppression(outputs['Detection'][1], 0.001, 0.65)
        # Metrics
        for i, output in enumerate(det_outputs):

            labels = targets[targets[:, 0] == i, 1:]
            correct = torch.zeros(output.shape[0], n_iou, dtype=torch.bool).cuda()
            if output.shape[0] == 0:
                if labels.shape[0]:
                    metrics.append((correct, *torch.zeros((3, 0)).cuda()))
                continue

            detections = output.clone()
            scale(detections[:, :4], samples[i].shape[1:], shapes[i][0], shapes[i][1])
            # Evaluate
            if labels.shape[0]:
                tbox = labels[:, 1:5].clone()  # target boxes
                tbox[:, 0] = labels[:, 1] - labels[:, 3] / 2  # top left x
                tbox[:, 1] = labels[:, 2] - labels[:, 4] / 2  # top left y
                tbox[:, 2] = labels[:, 1] + labels[:, 3] / 2  # bottom right x
                tbox[:, 3] = labels[:, 2] + labels[:, 4] / 2  # bottom right y
                scale(tbox, samples[i].shape[1:], shapes[i][0], shapes[i][1])

                correct = np.zeros((detections.shape[0], iou_v.shape[0]))
                correct = correct.astype(bool)

                t_tensor = torch.cat((labels[:, 0:1], tbox), 1)
                iou = box_iou(t_tensor[:, 1:], detections[:, :4])
                correct_class = t_tensor[:, 0:1] == detections[:, 5]
                # Greedy one-to-one matching of targets and detections per IoU threshold.
                for j in range(len(iou_v)):
                    x = torch.where((iou >= iou_v[j]) & correct_class)
                    if x[0].shape[0]:
                        matches = torch.cat((torch.stack(x, 1), iou[x[0], x[1]][:, None]), 1)
                        matches = matches.cpu().numpy()
                        if x[0].shape[0] > 1:
                            matches = matches[matches[:, 2].argsort()[::-1]]
                            matches = matches[np.unique(matches[:, 1], return_index=True)[1]]
                            matches = matches[np.unique(matches[:, 0], return_index=True)[1]]
                        correct[matches[:, 1].astype(int), j] = True
                correct = torch.tensor(correct, dtype=torch.bool, device=iou_v.device)

            metrics.append((correct, output[:, 4], output[:, 5], labels[:, 0]))

    # Compute metrics
    metrics = [torch.cat(x, 0).cpu().numpy() for x in zip(*metrics)]  # to numpy
    if len(metrics) and metrics[0].any():
        _tp, _fp, m_pre, m_rec, map50, mean_ap = compute_ap(*metrics)

    # Print results
    print('%10.4g' * 4 % (m_pre, m_rec, map50, mean_ap))
    # Avoid ZeroDivisionError on an empty validation split.
    denom = max(num_samples, 1)
    avg_psnr = psnr_total / denom
    avg_ssim = ssim_total / denom

    print('Average PSNR: %.4f' % avg_psnr)
    print('Average SSIM: %.4f' % avg_ssim)
    print('Val Det Loss: %.4f' % d_loss.avg)
    model.float()  # back to fp32 for training
    return map50, mean_ap, avg_psnr, avg_ssim, m_loss.avg, d_loss.avg, r_loss.avg


def main(overrides=None):
    """
    Build the training configuration, seed RNGs, configure workers and train.

    Args:
        overrides: optional dict of config keys to replace before training,
            e.g. ``{'epochs': 5}``. ``None`` (the default) keeps the values
            below, so ``main()`` behaves exactly as before.
    """
    args = {
        'input_size': 640,
        'batch_size': 16,
        'local_rank': 0,
        'epochs': 100,
        'cosine_epochs': 100,
        'world_size': 1,
        'lr0': 1e-4,                     # backbone param-group LR
        'lr1': 1e-2,                     # 'res' param-group LR
        'lr2': 1e-4,                     # 'det' param-group LR
        'lrf': 1e-2,                     # final LR factor of the cosine schedule
        'lrf2': 1e-2,                    # NOTE(review): not read by train()/test() here
        'weight_decay': 5e-4,
        'warmup_bias_lr': 0,
        'Freeze_Epoch': 0,
        'Freeze_Train': False,
        'warmup_epochs': 7,
        'warmup_momentum': 0.8,
        'momentum': 0.93700000,
        'nc': 5,
        'box': 7.5,                      # box loss gain
        'cls': 0.5,                      # cls loss gain
        'dfl': 1.5,
        'hsv_h': 0.015000,               # image HSV-Hue augmentation (fraction)
        'hsv_s': 0.700000,               # image HSV-Saturation augmentation (fraction)
        'hsv_v': 0.400000,               # image HSV-Value augmentation (fraction)
        'degrees': 0.0000,               # image rotation (+/- deg)
        'translate': 0.10,               # image translation (+/- fraction)
        'scale': 0.500000,               # image scale (+/- gain)
        'shear': 0.000000,               # image shear (+/- deg)
        'flip_ud': 0.0000,               # image flip up-down (probability)
        'flip_lr': 0.5000,               # image flip left-right (probability)
        'mosaic': 0.00000,               # image mosaic (probability)
        'mix_up': 0.00000,               # image mix-up (probability)
        'det': 0.6,                      # weight of the detection loss in the total loss
        'res': 0.4,                      # weight of the L1 restoration loss in the total loss
    }
    if overrides:
        args.update(overrides)

    setup_seed()
    setup_multi_processes()

    train(args)
    # To only evaluate the checkpoint loaded inside test(), call: test(args)





In [15]:
# Run the pipeline: seed RNGs, configure worker processes, then train with main()'s config.
main()

FileNotFoundError: [Errno 2] No such file or directory: '/kaggle/input/utilis2/yolov8_s.pt'

In [None]:
# depth = [1, 2, 2]
# width = [3, 32, 64, 128, 256, 512]
# model = YOLO(width, depth, 5).cuda()
# for i,(k,v) in enumerate(model.named_parameters()):
#     print(f'{i}: {k}')

In [None]:
# # # Define the mean and std
# args ={
#     'input_size':640,
#     'batch_size':8,
#     'local_rank':0,
#     'epochs': 100,
#     'cosine_epochs':100,
#     'world_size': 1,
#     'lr0':1e-2,
#     'lr1':1e-2,
#     'lr2':1e-2,
#     'lrf':1e-2,
#     'lrf2':1e-2,
#     'weight_decay':5e-4,
#     'warmup_bias_lr':0,
#     'Freeze_Epoch':0,
#     'Freeze_Train': False,
#     'warmup_epochs': 3,
#     'warmup_momentum': 0.8,
#     'momentum':0.93700000,
#     'nc':5,
#     'box': 7.5 ,                     # box loss gain
#     'cls': 0.5,                      # cls loss gain
#     'dfl': 1.5 ,
#     'hsv_h': 1.000000,               # image HSV-Hue augmentation (fraction)
#     'hsv_s': 0.700000,              # image HSV-Saturation augmentation (fraction)
#     'hsv_v': 0.400000,               # image HSV-Value augmentation (fraction)
#     'degrees': 0.0000,               # image rotation (+/- deg)
#     'translate': 0.10,               # image translation (+/- fraction)
#     'scale': 0.500000,               # image scale (+/- gain)
#     'shear': 5.000000,               # image shear (+/- deg)
#     'flip_ud': 0.5000,               # image flip up-down (probability)
#     'flip_lr': 0.5000,               # image flip left-right (probability)
#     'mosaic': 0.00000,               # image mosaic (probability)
#     'mix_up': 0.00000,               # image mix-up (probability)
# }
# train_dir = '/kaggle/input/vocfog/vocfog/train'
    

# train_data = YOLODataset(train_dir, 512, args, True)
# train_loader = DataLoader(train_data, batch_size=1, num_workers=4,
#                                shuffle=True,collate_fn=collateFunction)

# for i, (rainy_images, clean_images, targets,shapes) in enumerate(train_loader):
#     print(f'Train batch {i+1}:')
#     print(f'Rainy images shape: {rainy_images.shape}')
#     _, _, height, width = rainy_images.shape
#     # Denormalize the first image in the batch
#     pad_w, pad_h = shapes[0][1][1]
#     pad_w, pad_h = int(pad_w), int(pad_h)
#     nh=shapes[0][0][0] * shapes[0][1][0][0]
#     nw=shapes[0][0][1]*shapes[0][1][0][1]
#     # Display the first image in the batch
#     targets[:, 2:] *= torch.tensor((width, height, width, height)) # to pixels
#     labels = targets[targets[:, 0] == 0, 1:]
#     tbox = labels[:, 1:5].clone()  # target boxes
#     tbox[:, 0] = labels[:, 1] - labels[:, 3] / 2  # top left x
#     tbox[:, 1] = labels[:, 2] - labels[:, 4] / 2  # top left y
#     tbox[:, 2] = labels[:, 1] + labels[:, 3] / 2  # bottom right x
#     tbox[:, 3] = labels[:, 2] + labels[:, 4] / 2  # bottom right y
#     scale(tbox, (height,width), (nh,nw))
#     fig, ax = plt.subplots(1,2, figsize=(16,8))

#     rainy_image = rainy_images[0][:, pad_h:height-pad_h, pad_w:width-pad_w]
#     ax[0].imshow(rainy_image.permute(1, 2, 0).clamp(0, 1))

#     for box in tbox:
#         rect = patches.Rectangle((box[0], box[3]), box[2]-box[0], box[1]-box[3], linewidth=1, edgecolor='r', facecolor='none')
#         ax[0].add_patch(rect)

#     clean_image = clean_images[0][:, pad_h:height-pad_h, pad_w:width-pad_w]
#     ax[1].imshow(clean_image.permute(1, 2, 0).clamp(0, 1))

#     for box in tbox:
#         rect = patches.Rectangle((box[0], box[3]), box[2]-box[0], box[1]-box[3], linewidth=1, edgecolor='r', facecolor='none')
#         ax[1].add_patch(rect)
#     plt.tight_layout()

#     plt.show()
# #     rainy_image = rainy_images[1] * std + mean
# #     print(annotations)
# #     # Display the first image in the batch
# #     fig, ax = plt.subplots(1)
# #     ax.imshow(rainy_image.permute(1, 2, 0).clamp(0, 1))

#     # Draw the bounding boxes
# #     for box in annotations:
# #         box = box[2:]
# #         rect = patches.Rectangle(((box[0]-box[2]/2)*512, (box[1]-box[3]/2)*512), box[2]*512, box[3]*512, linewidth=1, edgecolor='r', facecolor='none')
# #         ax.add_patch(rect)

# #     plt.show()

#     if i == 3:  # Stop after 2 batches
#         break