In [2]:
import argparse
import datetime
import json
import os
import time
from pathlib import Path

import kornia.augmentation as K
import numpy as np
import timm
import timm.optim.optim_factory as optim_factory
import torch
import torch.backends.cudnn as cudnn
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torchvision.transforms as tv_transforms
import wandb
import yaml
from kornia.augmentation import AugmentationSequential
from kornia.constants import Resample
from torch.utils.data import Subset
from torch.utils.tensorboard import SummaryWriter
from wandb_log import WANDB_LOG_IMG_CONFIG

import models_vit
import models_mae
import models_vit_segmentation
import util.lr_decay as lrd
import util.misc as misc
#from dataloaders.utils import get_dataset_and_sampler, get_eval_dataset_and_transform
from engine_finetune import evaluate, train_one_epoch
from lib.transforms import CustomCompose
from PIL import Image
from timm.models.layers import trunc_normal_
from util.lars import LARS
from util.misc import NativeScalerWithGradNormCount as NativeScaler

from marinedebrisdetector.marinedebrisdetector.data.marinedebrisdatamodule import MarineDebrisDataModule

Image.MAX_IMAGE_PIXELS = 1000000000




In [3]:
def get_args_parser():
    parser = argparse.ArgumentParser(
        "MAE linear probing for image classification", add_help=False
    )

    parser.add_argument(
        "--checkpoint_interval", default=20, type=int, help="How often to checkpoint"
    )
    parser.add_argument(
        "--batch_size",
        default=64,
        type=int,
        help="Batch size per GPU (effective batch size is batch_size * accum_iter * # gpus",
    )

    parser.add_argument(
        "--print_freq",
        default=20,
        type=int,
        help="How often (iters) print results to wandb",
    )
    parser.add_argument(
        "--drop_path",
        type=float,
        default=0.1,
        metavar="PCT",
        help="Drop path rate (default: 0.1)",
    )
    parser.add_argument("--epochs", default=400, type=int)
    parser.add_argument(
        "--accum_iter",
        default=1,
        type=int,
        help="Accumulate gradient iterations (for increasing the effective batch size under memory constraints)",
    )
    parser.add_argument("--config", default="config.yaml", type=str, help="Config file")
    parser.add_argument("--name", default="", type=str, help="Name of wandb entry")

    # Model parameters
    parser.add_argument(
        "--model",
        default="vit_large_patch16",
        type=str,
        metavar="MODEL",
        help="Name of model to train",
    )

    parser.add_argument("--linear_layer_scale", default=1.0, type=float, help="")

    # Model parameters
    parser.add_argument(
        "--wandb_id", default=None, type=str, help="Wandb id, useful for resuming runs"
    )

    parser.add_argument("--input_size", default=128, type=int, help="images input size")
    parser.add_argument(
        "--target_size", nargs="*", type=int, help="images input size", default=[128]
    )

    parser.add_argument(
        "--source_size", nargs="*", type=int, help="images source size", default=[128]
    )

    parser.add_argument(
        "--mask_ratio",
        default=0.75,
        type=float,
        help="Masking ratio (percentage of removed patches).",
    )

    parser.add_argument("--scale_min", default=0.5, type=float, help="Min RRC scale")

    parser.add_argument("--scale_max", default=1.0, type=float, help="Max RRC scale")

    parser.add_argument(
        "--norm_pix_loss",
        action="store_true",
        help="Use (per-patch) normalized pixels as targets for computing loss",
    )

    parser.add_argument(
        "--restart",
        action="store_true",
        help="Load the checkpoint, but start from epoch 0",
    )
    parser.set_defaults(norm_pix_loss=False)

    # Optimizer parameters
    parser.add_argument(
        "--weight_decay", type=float, default=0.05, help="weight decay (default: 0.05)"
    )

    parser.add_argument(
        "--lr",
        type=float,
        default=None,
        metavar="LR",
        help="learning rate (absolute lr)",
    )
    parser.add_argument(
        "--blr",
        type=float,
        default=1e-3,
        metavar="LR",
        help="base learning rate: absolute_lr = base_lr * total_batch_size / 256",
    )
    parser.add_argument(
        "--layer_decay",
        type=float,
        default=0.75,
        help="layer-wise lr decay from ELECTRA/BEiT",
    )
    parser.add_argument(
        "--min_lr",
        type=float,
        default=0.0,
        metavar="LR",
        help="lower lr bound for cyclic schedulers that hit 0",
    )

    parser.add_argument(
        "--warmup_epochs", type=int, default=0, metavar="N", help="epochs to warmup LR"
    )

    parser.add_argument(
        "--output_dir",
        default="./output_dir",
        help="path where to save, empty for no saving",
    )
    parser.add_argument(
        "--log_dir", default="./output_dir", help="path where to tensorboard log"
    )
    parser.add_argument(
        "--device", default="cuda", help="device to use for training / testing"
    )
    parser.add_argument("--eval_only", action="store_true", help="Only do KNN Eval")
    parser.add_argument(
        "--eval_dataset",
        default="resisc",
        type=str,
        help="name of eval dataset to use. Options are resisc (default), airound, mlrsnet, and fmow.",
    )
    parser.add_argument(
        "--eval_path", default="resisc45", type=str, help="dataset path"
    )
    parser.add_argument(
        "--eval_gsd",
        action="store_true",
        help="USE GSD Relative Embedding with base=224x224",
    )
    parser.add_argument(
        "--eval_base_resolution",
        default=1.0,
        type=float,
        help="Global Multiplication factor of Positional Embedding Resolution in KNN",
    )
    parser.add_argument(
        "--eval_reference_resolution",
        default=224,
        type=float,
        help="Reference input resolution to scale GSD factor by in eval.",
    )
    parser.add_argument(
        "--eval_scale", default=224, type=int, help="The size of the eval input."
    )
    parser.add_argument("--eval", action="store_true", help="Perform evaluation only")
    parser.add_argument(
        "--dist_eval",
        action="store_true",
        default=False,
        help="Enabling distributed evaluation (recommended during training for faster monitor",
    )
    parser.set_defaults(eval_only=False)
    parser.add_argument(
        "--no_autoresume",
        action="store_true",
        help="Dont autoresume from last checkpoint",
    )
    parser.set_defaults(no_autoresume=False)
    parser.add_argument("--seed", default=0, type=int)
    parser.add_argument("--resume", default="", help="resume from checkpoint")

    parser.add_argument(
        "--start_epoch", default=0, type=int, metavar="N", help="start epoch"
    )
    parser.add_argument("--num_workers", default=10, type=int)
    parser.add_argument(
        "--pin_mem",
        action="store_true",
        help="Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.",
    )
    parser.add_argument("--no_pin_mem", action="store_false", dest="pin_mem")
    parser.set_defaults(pin_mem=True)

    # distributed training parameters
    parser.add_argument(
        "--world_size", default=1, type=int, help="number of distributed processes"
    )
    parser.add_argument("--local_rank", default=-1, type=int)
    parser.add_argument("--dist_on_itp", action="store_true")
    parser.add_argument(
        "--dist_url", default="env://", help="url used to set up distributed training"
    )

    parser.add_argument(
        "--base_resolution",
        default=2.5,
        type=float,
        help="The base resolution to use for the period of the sin wave for positional embeddings",
    )

    # * Finetuning params
    parser.add_argument(
        "--finetune",
        default=True,
        help="If true, finetune. If false, linear probe.",
    )
    parser.add_argument(
        "--checkpoint_path", default='/home/emanuele/jobs/adopt/scale-MAE-marinelitter/IGARSS2024/scalemae-vitlarge-800.pth', type=str, help="Path to checkpoint weights."
    )
    parser.add_argument("--global_pool", action="store_true")
    parser.set_defaults(global_pool=False)
    parser.add_argument(
        "--cls_token",
        action="store_false",
        dest="global_pool",
        help="Use class token instead of global pool for classification",
    )

    parser.add_argument(
        "--nb_classes",
        default=1000,
        type=int,
        help="number of the classification types",
    )

    parser.add_argument('--data-path', type=str, default="/home/emanuele/data/ADOPT/marinedebris")
    parser.add_argument('--image-size', type=int, default=128)
    parser.add_argument('--download', action="store_false")
    parser.add_argument('--no-label-refinement', action="store_false")
    parser.add_argument('--no-s2ships', action="store_false")
    parser.add_argument('--no-marida', action="store_true")

    return parser

In [4]:
args = get_args_parser()
args = args.parse_args()
if args.output_dir:
    Path(args.output_dir).mkdir(parents=True, exist_ok=True)

In [5]:
misc.init_distributed_mode(args)
num_tasks = misc.get_world_size()
global_rank = misc.get_rank()

#print(f"job dir: {os.path.dirname(os.path.realpath(__file__))}")
#print(f"{args}".replace(", ", ",\n"))

device = torch.device(args.device)

# fix the seed for reproducibility
seed = args.seed + misc.get_rank()
torch.manual_seed(seed)
np.random.seed(seed)

cudnn.benchmark = True

Not using distributed mode


In [6]:
######## backwards compatability hacks
if not isinstance(args.target_size, list):
    args.target_size = [args.target_size]

if not isinstance(args.source_size, list):
    args.source_size = [args.source_size]
########################################

In [7]:
# Validate that all sizes in target_size are multiples of 16
if len(args.target_size) > 0:
    assert all(
        [type(i) == int for i in args.target_size]
    ), "Invalid multiscale input, it should be a json list of int, e.g. [224,448]"
    assert all(
        [i % 16 == 0 for i in args.target_size]
    ), "Decoder resolution must be a multiple of patch size (16)"


In [8]:
# set a random wandb id before fixing random seeds
random_wandb_id = wandb.util.generate_id()

In [9]:
marinedebris_datamodule = MarineDebrisDataModule(data_root=args.data_path,
                                        image_size=args.image_size,
                                        workers=args.num_workers,
                                        batch_size=args.batch_size,
                                        no_label_refinement=args.no_label_refinement,
                                        no_s2ships=args.no_s2ships,
                                        no_marida=args.no_marida,
                                        download=False)

marinedebris_datamodule.prepare_data()
marinedebris_datamodule.setup()



Dataset Composition total

train
flobs_dataset (train): 39194
shipsdataset: 29833
MARIDA (train): 2017

val
refinedflobs_val: 2451
MARIDA (val): 1185

test
flobstestdataset: 2197
maridatestdataset: 872


Dataset Composition debris/non-debris
train 
flobs_dataset (train): 19587/19607
shipsdataset: 0/29833
MARIDA (train): 930/1087

val
refinedflobs_val: 868/1583
MARIDA (val): 616/569

test
flobstestdataset: 903/1294
maridatestdataset: 270/602


In [10]:
data_loader_train = marinedebris_datamodule.train_dataloader()
data_loader_val = marinedebris_datamodule.val_dataloader()

In [10]:
'''for image, mask, _ in data_loader_val:
    print(image.shape)'''

'for image, mask, _ in data_loader_val:\n    print(image.shape)'

In [11]:
# sample_image = next(iter(dataset_train))

"""mean_ndvi = 0
mean_fdi = 0
mean_grayscale = 0
x2_ndvi = 0
x2_fdi = 0
x2_grayscale = 0
i = 0

for image, mask, _ in dataset_train:
    print(i)
    i += 1
    grayscale = image[0]
    fdi = image[1]
    ndvi = image[2]
    mean_ndvi += torch.mean(ndvi)
    mean_fdi += torch.mean(fdi)
    mean_grayscale += torch.mean(grayscale)
    x2_ndvi += torch.sum(ndvi**2)
    x2_fdi += torch.sum(fdi**2)
    x2_grayscale += torch.sum(ndvi**2)
mean_ndvi = mean_ndvi/i
mean_fdi = mean_fdi/i
mean_grayscale = mean_grayscale/i
std_ndvi = torch.sqrt(x2_ndvi/(i*128*128)-mean_ndvi**2)
std_fdi = torch.sqrt(x2_fdi/(i*128*128)-mean_fdi**2)
std_greyscale = torch.sqrt(x2_grayscale/(i*128*128)-mean_grayscale**2)
print(mean_grayscale)
print(mean_ndvi)
print(mean_fdi)
print(std_greyscale)
print(std_ndvi)
print(std_fdi)"""




'mean_ndvi = 0\nmean_fdi = 0\nmean_grayscale = 0\nx2_ndvi = 0\nx2_fdi = 0\nx2_grayscale = 0\ni = 0\n\nfor image, mask, _ in dataset_train:\n    print(i)\n    i += 1\n    grayscale = image[0]\n    fdi = image[1]\n    ndvi = image[2]\n    mean_ndvi += torch.mean(ndvi)\n    mean_fdi += torch.mean(fdi)\n    mean_grayscale += torch.mean(grayscale)\n    x2_ndvi += torch.sum(ndvi**2)\n    x2_fdi += torch.sum(fdi**2)\n    x2_grayscale += torch.sum(ndvi**2)\nmean_ndvi = mean_ndvi/i\nmean_fdi = mean_fdi/i\nmean_grayscale = mean_grayscale/i\nstd_ndvi = torch.sqrt(x2_ndvi/(i*128*128)-mean_ndvi**2)\nstd_fdi = torch.sqrt(x2_fdi/(i*128*128)-mean_fdi**2)\nstd_greyscale = torch.sqrt(x2_grayscale/(i*128*128)-mean_grayscale**2)\nprint(mean_grayscale)\nprint(mean_ndvi)\nprint(mean_fdi)\nprint(std_greyscale)\nprint(std_ndvi)\nprint(std_fdi)'

In [12]:
'''dataset_train = marinedebris_datamodule.train_dataloader()
sampler_train = torch.utils.data.DistributedSampler(
            dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True
        )
batch_size_factor = 1
data_loader_train = torch.utils.data.DataLoader(
        dataset_train,
        sampler=sampler_train,
        batch_size=int(args.batch_size * batch_size_factor),
        num_workers=args.num_workers,
        pin_memory=args.pin_mem,
        drop_last=True,
    )
dataset_val = marinedebris_datamodule.val_dataloader()
sampler_val = torch.utils.data.DistributedSampler(
        dataset_val,
        num_replicas=num_tasks,
        rank=global_rank,
        shuffle=False,
        drop_last=False,
    )
data_loader_val = torch.utils.data.DataLoader(
        dataset_val,
        batch_size=args.batch_size,
        sampler=sampler_val,
        num_workers=args.num_workers,
    )'''

'dataset_train = marinedebris_datamodule.train_dataloader()\nsampler_train = torch.utils.data.DistributedSampler(\n            dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True\n        )\nbatch_size_factor = 1\ndata_loader_train = torch.utils.data.DataLoader(\n        dataset_train,\n        sampler=sampler_train,\n        batch_size=int(args.batch_size * batch_size_factor),\n        num_workers=args.num_workers,\n        pin_memory=args.pin_mem,\n        drop_last=True,\n    )\ndataset_val = marinedebris_datamodule.val_dataloader()\nsampler_val = torch.utils.data.DistributedSampler(\n        dataset_val,\n        num_replicas=num_tasks,\n        rank=global_rank,\n        shuffle=False,\n        drop_last=False,\n    )\ndata_loader_val = torch.utils.data.DataLoader(\n        dataset_val,\n        batch_size=args.batch_size,\n        sampler=sampler_val,\n        num_workers=args.num_workers,\n    )'

In [11]:
model = models_vit.__dict__[args.model](
    img_size=args.input_size,
    num_classes=args.nb_classes,
    global_pool=args.global_pool,
)
'''model = models_vit_segmentation.__dict__[args.model](
    img_size=args.input_size,
    num_classes=args.nb_classes,
    global_pool=args.global_pool,
)'''
'''model = models_mae.__dict__['mae_'+args.model](
    fixed_output_size = 0
)'''


"model = models_mae.__dict__['mae_'+args.model](\n    fixed_output_size = 0\n)"

In [12]:
if args.checkpoint_path:
    checkpoint = torch.load(args.checkpoint_path, map_location="cpu")
    print('done')

    print(f"Load pre-trained checkpoint from: {args.checkpoint_path}")
    checkpoint_model = checkpoint["model"]
    state_dict = model.state_dict()
    for k in ["head.weight", "head.bias"]:
        if (
            k in checkpoint_model
            and checkpoint_model[k].shape != state_dict[k].shape
        ):
            print(f"Removing key {k} from pretrained checkpoint")
            del checkpoint_model[k]
    if args.input_size != 224:
        if (
            "pos_embed" in checkpoint_model
            and checkpoint_model["pos_embed"].shape != state_dict["pos_embed"].shape
        ):
            print(f"Removing key pos_embed from pretrained checkpoint")
            del checkpoint_model["pos_embed"]

    # interpolate position embedding
    # We do not do this in Scale-MAE since we use a resolution-specific
    # pos embedding in forward_features
    # interpolate_pos_embed(model, checkpoint_model)

    # load pre-trained model
    msg = model.load_state_dict(checkpoint_model, strict=False)
    print(msg)

    if args.global_pool:
        assert set(msg.missing_keys) == {
            "head.weight",
            "head.bias",
            "fc_norm.weight",
            "fc_norm.bias",
        }
    else:
        if args.input_size != 224:
            assert set(msg.missing_keys) == {
                "head.weight",
                "head.bias",
                "pos_embed",
            }
        else:
            assert set(msg.missing_keys) == {"head.weight", "head.bias"}

    if not args.eval:
        # manually initialize fc layer: following MoCo v3
        trunc_normal_(model.head.weight, std=0.01)
        # model.head.bias.data.zero_()

done
Load pre-trained checkpoint from: /home/emanuele/jobs/adopt/scale-MAE-marinelitter/IGARSS2024/scalemae-vitlarge-800.pth
Removing key pos_embed from pretrained checkpoint
_IncompatibleKeys(missing_keys=['pos_embed', 'head.weight', 'head.bias'], unexpected_keys=['mask_token', 'decoder_pos_embed', 'decoder_embed.weight', 'decoder_embed.bias', 'fpn.fpn1.0.weight', 'fpn.fpn1.0.bias', 'fpn.fpn1.1.ln.weight', 'fpn.fpn1.1.ln.bias', 'fpn.fpn1.3.weight', 'fpn.fpn1.3.bias', 'fpn.fpn2.0.weight', 'fpn.fpn2.0.bias', 'fcn_high.proj.weight', 'fcn_high.proj.bias', 'fcn_high.conv_blocks.0.convs.1.weight', 'fcn_high.conv_blocks.0.convs.1.bias', 'fcn_high.conv_blocks.0.convs.3.weight', 'fcn_high.conv_blocks.0.convs.3.bias', 'fcn_high.conv_blocks.0.residual.weight', 'fcn_high.conv_blocks.0.residual.bias', 'fcn_high.conv_blocks.1.convs.1.weight', 'fcn_high.conv_blocks.1.convs.1.bias', 'fcn_high.conv_blocks.1.convs.3.weight', 'fcn_high.conv_blocks.1.convs.3.bias', 'fcn_high.conv_blocks.1.residual.weight

In [13]:
patch_size = 16
out_chans = 1 # number of feature maps projected
'''model.head = torch.nn.Sequential(
    torch.nn.Linear(model.head.in_features, patch_size**2 * out_chans, bias=True),
    torch.nn.Conv2d(in_channels=model.head.in_features, out_channels=32, kernel_size=1, stride=1,
                                    padding='same'), torch.nn.LeakyReLU(0.1), torch.nn.Conv2d(in_channels=32, out_channels=16, kernel_size=1, stride=1,
                                    padding='same'), torch.nn.LeakyReLU(0.1),torch.nn.Conv2d(in_channels=16, out_channels=1, kernel_size=1, stride=1,
                                    padding='same')
)'''
model.head = torch.nn.Sequential(
    torch.nn.Linear(model.head.in_features, patch_size**2 * out_chans, bias=True)
)


In [14]:
if not args.finetune:
    # Linear probe
    # freeze all but the head
    for _, p in model.named_parameters():
        p.requires_grad = False
    for _, p in model.head.named_parameters():
        p.requires_grad = True
else:
    for _, p in model.named_parameters():
        p.requires_grad = True
    for _, p in model.head.named_parameters():
        p.requires_grad = True

# ScaleMAE does not use the pos_embed within ViT
model.pos_embed.requires_grad = False

model.to(device)

VisionTransformer(
  (patch_embed): PatchEmbedUnSafe(
    (proj): Conv2d(3, 1024, kernel_size=(16, 16), stride=(16, 16))
  )
  (pos_drop): Dropout(p=0.0, inplace=False)
  (blocks): ModuleList(
    (0-23): 24 x Block(
      (norm1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)
      (attn): Attention(
        (qkv): Linear(in_features=1024, out_features=3072, bias=True)
        (attn_drop): Dropout(p=0.0, inplace=False)
        (proj): Linear(in_features=1024, out_features=1024, bias=True)
        (proj_drop): Dropout(p=0.0, inplace=False)
      )
      (drop_path): Identity()
      (norm2): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)
      (mlp): Mlp(
        (fc1): Linear(in_features=1024, out_features=4096, bias=True)
        (act): GELU(approximate='none')
        (fc2): Linear(in_features=4096, out_features=1024, bias=True)
        (drop): Dropout(p=0.0, inplace=False)
      )
    )
  )
  (norm): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)
  (head): 

In [15]:
model_without_ddp = model
n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)

eff_batch_size = args.batch_size * args.accum_iter * misc.get_world_size()

In [16]:
if args.lr is None:  # only base_lr is specified
    args.lr = args.blr * eff_batch_size / 256

'''if args.distributed:
    model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
    model_without_ddp = model.module'''

if not args.finetune:
    # Linear probe
    param_groups = optim_factory.param_groups_layer_decay(
        model_without_ddp, args.weight_decay
    )
else:
    # build optimizer with layer-wise lr decay (lrd)
    param_groups = lrd.param_groups_lrd(
        model_without_ddp,
        args.weight_decay,
        no_weight_decay_list=model_without_ddp.no_weight_decay(),
        layer_decay=args.layer_decay,
    )
    param_groups[-1]["lr_scale"] *= args.linear_layer_scale
optimizer = torch.optim.AdamW(param_groups, lr=args.lr, betas=(0.9, 0.95))
print(optimizer)
loss_scaler = NativeScaler()

AdamW (
Parameter Group 0
    amsgrad: False
    betas: (0.9, 0.95)
    capturable: False
    differentiable: False
    eps: 1e-08
    foreach: None
    fused: None
    lr: 0.00025
    lr_scale: 0.0007525434581650003
    maximize: False
    weight_decay: 0.0

Parameter Group 1
    amsgrad: False
    betas: (0.9, 0.95)
    capturable: False
    differentiable: False
    eps: 1e-08
    foreach: None
    fused: None
    lr: 0.00025
    lr_scale: 0.0007525434581650003
    maximize: False
    weight_decay: 0.05

Parameter Group 2
    amsgrad: False
    betas: (0.9, 0.95)
    capturable: False
    differentiable: False
    eps: 1e-08
    foreach: None
    fused: None
    lr: 0.00025
    lr_scale: 0.0010033912775533338
    maximize: False
    weight_decay: 0.0

Parameter Group 3
    amsgrad: False
    betas: (0.9, 0.95)
    capturable: False
    differentiable: False
    eps: 1e-08
    foreach: None
    fused: None
    lr: 0.00025
    lr_scale: 0.0010033912775533338
    maximize: False
    we

In [17]:
criterion = torch.nn.BCEWithLogitsLoss()
print("criterion = %s" % str(criterion))

criterion = BCEWithLogitsLoss()


In [18]:
misc.load_model(
    args=args,
    model_without_ddp=model_without_ddp,
    optimizer=optimizer,
    loss_scaler=loss_scaler,
)

In [22]:
def evaluate(
    data_loader,
    model,
    device,
    eval_base_resolution=1.0,
    gsd_embed=False,
    eval_scale=512,
    reference_size=512,
):
    gsd_ratio = eval_base_resolution
    if gsd_embed:
        gsd_ratio = gsd_ratio * (reference_size / eval_scale)

    criterion = torch.nn.BCEWithLogitsLoss()

    metric_logger = misc.MetricLogger(delimiter="  ")
    header = "Test:"

    # switch to evaluation mode
    model.eval()

    for images, target, _ in data_loader:        
        images = images.to(device, non_blocking=True)
        target = target.to(device, non_blocking=True)
        target = target.unsqueeze(1)

        # compute output
        with torch.cuda.amp.autocast():
            output = model(
                images,
                input_res=torch.ones(len(images)).float().to(images.device) * gsd_ratio,
            )
            print(output.shape)
            output = model.unpatchify(output,1)
            print(output.shape)
            print(target.shape)
            loss = criterion(output, target)

        #acc1, acc5 = accuracy(output, target, topk=(1, 5))

        
        print("* loss "+str(loss.item()))


args.eval = True



if args.eval:
    evaluate(
        data_loader_train,
        model,
        device,
        eval_base_resolution=args.eval_base_resolution,
        gsd_embed=args.eval_gsd,
        eval_scale=args.eval_scale,
        reference_size=args.eval_reference_resolution,
    )
    exit(0)

torch.Size([64, 3, 128, 128])
torch.Size([64, 64, 1024])
torch.Size([64, 65, 1024])
torch.Size([64, 65, 1024])
torch.Size([64, 65, 1024])
torch.Size([64, 65, 1024])
torch.Size([64, 64, 1024])
torch.Size([64, 64, 256])
torch.Size([64, 1, 128, 128])
torch.Size([64, 1, 128, 128])
* loss 0.7056969404220581
torch.Size([64, 3, 128, 128])
torch.Size([64, 64, 1024])
torch.Size([64, 65, 1024])
torch.Size([64, 65, 1024])


OutOfMemoryError: CUDA out of memory. Tried to allocate 18.00 MiB. GPU 0 has a total capacty of 23.69 GiB of which 6.69 MiB is free. Process 2480416 has 1.70 GiB memory in use. Including non-PyTorch memory, this process has 21.98 GiB memory in use. Of the allocated memory 19.51 GiB is allocated by PyTorch, and 2.15 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

In [33]:
def evaluate(
    data_loader,
    model,
    device,
    eval_base_resolution=1.0,
    gsd_embed=False,
    eval_scale=512,
    reference_size=512,
):
    gsd_ratio = eval_base_resolution
    if gsd_embed:
        gsd_ratio = gsd_ratio * (reference_size / eval_scale)

    criterion = torch.nn.BCEWithLogitsLoss()

    metric_logger = misc.MetricLogger(delimiter="  ")
    header = "Test:"

    # switch to evaluation mode
    model.eval()

    for (samples, labels, c) in metric_logger.log_every(data_loader, 10, header):
        images = samples
        target = labels
        c = c
        print('c')
        print(c.shape)
        c = c.to(device, non_blocking=True)
        
        images = images.to(device, non_blocking=True)
        target = target.to(device, non_blocking=True)

        # compute output
        with torch.cuda.amp.autocast():
            output = model(
                images,
                input_res=torch.ones(len(images)).float().to(images.device) * gsd_ratio,
            )
            print(output.shape)
            output = model.unpatchify(output,1)
            print(output.shape)
            print(target.shape)
            loss = criterion(output, target)

        #acc1, acc5 = accuracy(output, target, topk=(1, 5))

        batch_size = images.shape[0]
        metric_logger.update(loss=loss.item())
        metric_logger.meters["BCE"].update(loss.item(), n=batch_size)
        #metric_logger.meters["acc1"].update(acc1.item(), n=batch_size)
        #metric_logger.meters["acc5"].update(acc5.item(), n=batch_size)
    # gather the stats from all processes
    # TODO: compute scores for segmentation at this step
    metric_logger.synchronize_between_processes()
    '''print(
        "* Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f} loss {losses.global_avg:.3f}".format(
            top1=metric_logger.acc1, top5=metric_logger.acc5, losses=metric_logger.loss
        )'''
    print("* loss {losses.global_avg:.3f}".format(losses=metric_logger.loss)
    )

    #return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}

In [34]:
args.eval = True

In [35]:
if args.eval:
    test_stats = evaluate(
        data_loader_val,
        model,
        device,
        eval_base_resolution=args.eval_base_resolution,
        gsd_embed=args.eval_gsd,
        eval_scale=args.eval_scale,
        reference_size=args.eval_reference_resolution,
    )
    print(
        f"BCE of the network on the test images: {test_stats['BCE']:.1f}%"
    )
    exit(0)

c


AttributeError: 'tuple' object has no attribute 'shape'

In [25]:
64*64*1024

4194304