In [1]:
import torch
import torch.nn as nn
import numpy as np
import torch.optim as optim
import os
import copy
import json
import pickle
import time
from datetime import timedelta

from utils import set_random, compute_entropy_v2
from training_utils import FocalLoss, validate_v2
from model import UNet
from model_v2 import deeplabv3_resnet50
from model_v3 import MC_DeepLab
from model_v4 import create_ViT_model
from create_dataset import (
    get_datasets,
    get_image_curr_labeled,
)
from routes import (
    CONTINUE_PATH,
    CONTINUE_FOLDER,
    PRINT_PATH,
    IGNORE_INDEX
)

#from other_queries import (
#    GT_query,
#    GTxSim_query,
#    density_entropy_query,
#    density_classEntropyV2_query,
#    density_query,
#    entropy_query,
#    BvSB_patch_query,
#    class_entropy_query,
#    class_entropy_patch_query,
#    class_entropy_ML_pred_patch_query,
#    class_entropy_video_query,
#    random_query,
#    coreset_query,
#    coreset_entropy_query,
#    COWAL_center_query,
#    k_means_entropy_query,
#    COWAL_entropy_query,
#    COWAL_entropy_video_query,
#    COWAL_entropy_patch_query,
#    MC_dropout_query,
#    BALD_query,
#    suggestive_annotation_query,
#   suggestive_annotation_patch_query,
#    VAAL_query,
#    oracle_query,
#    COWAL_classEntropy_query,
#    COWAL_classEntropy_video_query,
#    COWAL_classEntropy_v2_query,
#    COWAL_classEntropy_v2_video_query,
#    COWAL_classEntropy_v2_1_query,
#    COWAL_classEntropy_v2_2_query,
#    COWAL_classEntropy_v2_2_video_query,
#    COWAL_classEntropy_patch_query,
#    COWAL_classEntropy_patch_v2_query,
#    RIPU_PA_query,
#    entropy_patch_query,
#    revisiting_query,
#    revisiting_v2_query,
##    revisiting_adaptiveSP_query,
#    CBAL_query,
#    CBAL_v2_query,
#    pixelBal_query,
#    pixelBal_v2_query,
#    BvSB_patch_v2_query,
#)
from config import config

In [2]:
# Seed first (through the project's set_random helper) so the split below
# is produced under a known seed.
SEED = 0
set_random(SEED)

# Split the data. All eight values returned by get_image_curr_labeled are
# kept as notebook-level names because the dataset construction further
# down consumes them.
(
    curr_labeled,
    curr_selected_patches,
    train_data,
    val_data,
    test_data,
    TRAIN_SEQ,
    VAL_SEQ,
    TEST_SEQ,
) = get_image_curr_labeled(config, SEED, notebook=True)

# ---- Define the patience, expressed in iterations ----
# When N_LABEL > 1, the dataset is not one of those listed below, and
# BALANCED_INIT is off, the patience is derived from the video/query counts.
if (
    config["N_LABEL"] > 1
    and config["DATASET"] not in ("auris", "intuitive", "cityscapes", "pascal_VOC")
    and not config["BALANCED_INIT"]
):
    config["PATIENCE_ITER"] = 2 * (
        config["INIT_NUM_VIDEO"]
        + (config["NUM_ROUND"] - 1) * config["NUM_QUERY"]
    )
# Disabled variant for BALANCED_INIT runs, kept for reference:
#   config['PATIENCE_ITER'] = 2 * ((config['N_LABEL'] - 1) * 2 + (config['NUM_ROUND'] - 1) * config['NUM_QUERY'])

# Floor-divide the patience by the batch size.
# NOTE(review): this rewrites config in place. If none of the branches in
# this cell overwrites PATIENCE_ITER, re-running the cell divides the value
# again - confirm that is acceptable or restart the kernel before re-running.
config["PATIENCE_ITER"] //= config["BATCH_SIZE"]

if config["DATASET"] == "auris":  # auris v11
    # Hard-coded patience values, looked up by SEED (ten entries, so SEED
    # must stay within 0..9). This replaces the value computed above.
    auris_patience_iter = [132, 90, 130, 130, 130, 127, 142, 113, 113, 113]
    config["PATIENCE_ITER"] = auris_patience_iter[SEED]

# Bookkeeping initialised here for the later rounds.
start_round, test_scores = 0, []


# Seed again before anything else is built, to have the same model parameters.
set_random(SEED)

# ---- Set up the datasets and loaders ----
# Everything returned by the split step is forwarded by keyword.
dataset_kwargs = dict(
    curr_selected_patches=curr_selected_patches,
    train_data=train_data,
    val_data=val_data,
    test_data=test_data,
    TRAIN_SEQ=TRAIN_SEQ,
    VAL_SEQ=VAL_SEQ,
    TEST_SEQ=TEST_SEQ,
    notebook=True,
)
(
    train_dataloader,
    val_dataloader,
    test_dataloader,
    all_train_dataset,
    train_dataset_noAug,
    unlabeled_dataset,
) = get_datasets(config, curr_labeled, **dataset_kwargs)

# With a single round, the patience is set to len(train_dataloader)
# (presumably the number of batches per pass - confirm the loader type).
if config["NUM_ROUND"] == 1:
    config["PATIENCE_ITER"] = len(train_dataloader)


train video: ['01-00', '01-01', '01-03', '01-04', '01-07', '02-00', '02-01', '03-00', '03-01', '03-03', '03-04', '03-06', '03-07', '04-04', '04-07', '10-03', '11-01', '12-03', '13-02', '13-03', '15-00', '15-04', '15-05'], 23
val video: ['01-02', '01-06', '03-02', '03-05', '04-06', '08-02', '10-01', '13-00', '15-01', '15-02'], 10
test video: ['08-00', '08-01', '08-03', '08-04', '08-05', '08-06', '08-07', '08-08', '08-09', '09-00', '09-01', '09-02', '10-00', '10-02', '10-05', '10-06', '10-07', '11-00', '11-02', '11-03', '11-04', '13-04'], 22

sampling is: random
number of labeled train frames: 528 - 720 patches - 20.0 whole images
number of train frames: 1098 - total patches 39528
number of val frames: 629
number of test frames: 1148
number of unlabeled frames: 1098 - 38808 patches - 1078.0 whole images
first dataset sample: 15-00/frame4119.png



In [4]:
n_round = 0

# Build the network with one output channel per label and move it to the
# configured device.
# NOTE(review): the recorded run prints a warning that a 'pretrained'-style
# parameter is deprecated since 0.13. Check the signature of the project's
# deeplabv3_resnet50 wrapper before changing this call.
model = deeplabv3_resnet50(pretrained=False, num_classes=config["N_LABEL"])
model = model.to(config["DEVICE"])

# Independent deep copy of the freshly built model, switched to eval mode.
copy_model = copy.deepcopy(model)
copy_model = copy_model.to(config["DEVICE"])
copy_model.eval()

optimizer = optim.Adam(model.parameters(), lr=config["LEARNING_RATE"])

# A single label uses BCE-with-logits; otherwise the project's FocalLoss.
if config["N_LABEL"] == 1:
    criterion = nn.BCEWithLogitsLoss()
else:
    criterion = FocalLoss()

# Training-state counters, all starting from zero.
start_epoch = best_val_score = patience = curr_iter = nb_iter = 0
best_val_class_dices = {}

  f"The parameter '{pretrained_param}' is deprecated since 0.13 and may be removed in the future, "


In [5]:
### FETCH ONE TRAINING BATCH ###
# Only the first batch of the train loader is pulled here; no training loop
# runs in this cell. The loader yields three items per batch; the third (n)
# is not used below.
image, mask, n = next(iter(train_dataloader))

# Move the two tensors used for the forward pass/loss to the configured device.
image = image.to(config["DEVICE"])
mask = mask.to(config["DEVICE"])

In [6]:
# Inspect the shape of the fetched image batch. The recorded run shows
# torch.Size([4, 3, 220, 220]) - presumably (batch, channels, H, W); confirm.
image.shape

torch.Size([4, 3, 220, 220])

In [7]:
# Inspect the shape of the matching mask batch. The recorded run shows
# torch.Size([4, 220, 220]), i.e. one dimension fewer than the image batch.
mask.shape

torch.Size([4, 220, 220])

In [8]:

# Forward pass and loss on the single batch held in image/mask. No backward
# pass or optimizer step happens in this cell; the zero_grad call is left
# commented out.
# optimizer.zero_grad()
outputs = model(image)
# NOTE(review): the recorded run of this cell raised
#   TypeError: __call__() takes 2 positional arguments but 3 were given
# A later cell displays outputs.shape, so the forward pass presumably
# succeeded and the failure comes from the criterion call below (or from code
# it invokes). Check how the criterion class expects to be called before
# relying on this cell.
loss = criterion(outputs, mask)

TypeError: __call__() takes 2 positional arguments but 3 were given

In [9]:
# Inspect the shape of the model output. The recorded run shows
# torch.Size([4, 8, 220, 220]); the size-8 dimension is presumably the
# per-label channel (config["N_LABEL"]) - confirm.
outputs.shape

torch.Size([4, 8, 220, 220])

In [10]:
# Show the mask shape again, next to outputs.shape, for comparison; the
# recorded run shows torch.Size([4, 220, 220]).
mask.shape

torch.Size([4, 220, 220])