In [1]:
# !pip install wilds

Collecting wilds
  Downloading wilds-2.0.0-py3-none-any.whl (126 kB)
[K     |████████████████████████████████| 126 kB 5.3 MB/s 
[?25hCollecting pytz>=2020.4
  Downloading pytz-2022.1-py2.py3-none-any.whl (503 kB)
[K     |████████████████████████████████| 503 kB 44.7 MB/s 
[?25hCollecting scipy>=1.5.4
  Downloading scipy-1.7.3-cp37-cp37m-manylinux_2_12_x86_64.manylinux2010_x86_64.whl (38.1 MB)
[K     |████████████████████████████████| 38.1 MB 1.5 MB/s 
Collecting outdated>=0.2.0
  Downloading outdated-0.2.1-py3-none-any.whl (7.5 kB)
Collecting ogb>=1.2.6
  Downloading ogb-1.3.3-py3-none-any.whl (78 kB)
[K     |████████████████████████████████| 78 kB 6.2 MB/s 
Collecting pillow>=7.2.0
  Downloading Pillow-9.0.1-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (4.3 MB)
[K     |████████████████████████████████| 4.3 MB 34.2 MB/s 
Collecting littleutils
  Downloading littleutils-0.2.2.tar.gz (6.6 kB)
Building wheels for collected packages: littleutils
  Building wheel for lit

In [1]:
import numpy as np
from wilds import get_dataset
from wilds.common.data_loaders import get_train_loader
import torchvision.transforms as transforms
import torchvision.models as models
from wilds.common.data_loaders import get_eval_loader
import torch
from torch import nn
from tqdm import tqdm

In [2]:
# Load the full dataset, and download it if necessary
dataset = get_dataset(dataset="camelyon17", download=False)


In [3]:
BATCH_SIZE = 32
FRACTION = 0.33

# Get the training set
train_data = dataset.get_subset(
    "train",
    frac = FRACTION,
    transform=transforms.Compose(
        [
         transforms.ToTensor(),
         transforms.Normalize(mean=[0.485, 0.456, 0.406],
                        std=[0.229, 0.224, 0.225])]
    ),
)

print(len(train_data)) #302436 initially
# Prepare the standard data loader
train_loader = get_train_loader("standard", train_data, batch_size=BATCH_SIZE)

"""
# (Optional) Load unlabeled data
dataset = get_dataset(dataset="camelyon17", download=True, unlabeled=True)
unlabeled_data = dataset.get_subset(
    "test_unlabeled",
    transform=transforms.Compose(
        [transforms.Resize((448, 448)), transforms.ToTensor()]
    ),
)
unlabeled_loader = get_train_loader("standard", unlabeled_data, batch_size=16)
"""
"""
# Train loop
for labeled_batch, unlabeled_batch in zip(train_loader, unlabeled_loader):
    x, y, metadata = labeled_batch
    unlabeled_x, unlabeled_metadata = unlabeled_batch
    ...
"""

99804


'\n# Train loop\nfor labeled_batch, unlabeled_batch in zip(train_loader, unlabeled_loader):\n    x, y, metadata = labeled_batch\n    unlabeled_x, unlabeled_metadata = unlabeled_batch\n    ...\n'

In [4]:
# Get the test set
id_val_data = dataset.get_subset(
    "id_val",
    frac = FRACTION,
    transform=transforms.Compose(
        [
         transforms.ToTensor(),
         transforms.Normalize(mean=[0.485, 0.456, 0.406],
                        std=[0.229, 0.224, 0.225])]
    ),
)

print(len(id_val_data))

# Prepare the evaluation data loader
id_val_loader = get_eval_loader("standard", id_val_data, batch_size=BATCH_SIZE)


11075


In [5]:
# Get the test set
val_data = dataset.get_subset(
    "val",
    frac = FRACTION,
    transform=transforms.Compose(
        [
         transforms.ToTensor(),
         transforms.Normalize(mean=[0.485, 0.456, 0.406],
                        std=[0.229, 0.224, 0.225])]
    ),
)

print(len(val_data))

# Prepare the evaluation data loader
val_loader = get_eval_loader("standard", val_data, batch_size=BATCH_SIZE)

11518


In [6]:
# Get the test set
test_data = dataset.get_subset(
    "test",
    frac = FRACTION,
    transform=transforms.Compose(
        [
         transforms.ToTensor(),
         transforms.Normalize(mean=[0.485, 0.456, 0.406],
                        std=[0.229, 0.224, 0.225])]
    ),
)

print(len(test_data))

# Prepare the evaluation data loader
test_loader = get_eval_loader("standard", test_data, batch_size=BATCH_SIZE)

28068


In [7]:

# load the ResNet-18 model, with weights pretrained on ImageNet
resnet18_pretrained = models.resnet18(pretrained=True)

num_params = 0
print("Model's parameters: ")
for n, p in resnet18_pretrained.named_parameters():
    print('\t', n, ': ', p.size())
    num_params += p.numel()
print("Number of model parameters: ", num_params)

"""
formula [(W−K+2P)/S]+1.

W is the input volume
K is the Kernel size
P is the padding
S is the stride

"""

Model's parameters: 
	 conv1.weight :  torch.Size([64, 3, 7, 7])
	 bn1.weight :  torch.Size([64])
	 bn1.bias :  torch.Size([64])
	 layer1.0.conv1.weight :  torch.Size([64, 64, 3, 3])
	 layer1.0.bn1.weight :  torch.Size([64])
	 layer1.0.bn1.bias :  torch.Size([64])
	 layer1.0.conv2.weight :  torch.Size([64, 64, 3, 3])
	 layer1.0.bn2.weight :  torch.Size([64])
	 layer1.0.bn2.bias :  torch.Size([64])
	 layer1.1.conv1.weight :  torch.Size([64, 64, 3, 3])
	 layer1.1.bn1.weight :  torch.Size([64])
	 layer1.1.bn1.bias :  torch.Size([64])
	 layer1.1.conv2.weight :  torch.Size([64, 64, 3, 3])
	 layer1.1.bn2.weight :  torch.Size([64])
	 layer1.1.bn2.bias :  torch.Size([64])
	 layer2.0.conv1.weight :  torch.Size([128, 64, 3, 3])
	 layer2.0.bn1.weight :  torch.Size([128])
	 layer2.0.bn1.bias :  torch.Size([128])
	 layer2.0.conv2.weight :  torch.Size([128, 128, 3, 3])
	 layer2.0.bn2.weight :  torch.Size([128])
	 layer2.0.bn2.bias :  torch.Size([128])
	 layer2.0.downsample.0.weight :  torch.Size([12

'\nformula [(W−K+2P)/S]+1.\n\nW is the input volume\nK is the Kernel size\nP is the padding\nS is the stride\n\n'

In [8]:
# function counting the number of parameters and the number of trainable parameters of a model
# optionally, it will also display the layers
def check_model_parameters(model, display_layers=False):
  num_params = 0
  num_trainable_params = 0
  if display_layers==True:
    print("Model's parameters: ")
  for n, p in model.named_parameters():
      if display_layers == True:
        print('\t', n, ': ', p.size())
      num_params += p.numel()
      if p.requires_grad:
        num_trainable_params += p.numel()
  print("Number of model parameters: ", num_params)
  print("Number of trainable parameters: ", num_trainable_params)

In [9]:
# freeze the model parameters

# check the number of parameters and the number of trainable parameters
check_model_parameters(resnet18_pretrained, display_layers=False)

# freeze all the layers
# for param in resnet18_pretrained.parameters():
  # param.requires_grad = False

# check the number of parameters and the number of trainable parameters
check_model_parameters(resnet18_pretrained, display_layers=False)

Number of model parameters:  11689512
Number of trainable parameters:  11689512
Number of model parameters:  11689512
Number of trainable parameters:  11689512


In [10]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
num_epochs = 5

In [11]:

resnet18_pretrained.fc = nn.Linear(in_features=512, out_features=2, bias=True)

check_model_parameters(resnet18_pretrained, display_layers=True)

Model's parameters: 
	 conv1.weight :  torch.Size([64, 3, 7, 7])
	 bn1.weight :  torch.Size([64])
	 bn1.bias :  torch.Size([64])
	 layer1.0.conv1.weight :  torch.Size([64, 64, 3, 3])
	 layer1.0.bn1.weight :  torch.Size([64])
	 layer1.0.bn1.bias :  torch.Size([64])
	 layer1.0.conv2.weight :  torch.Size([64, 64, 3, 3])
	 layer1.0.bn2.weight :  torch.Size([64])
	 layer1.0.bn2.bias :  torch.Size([64])
	 layer1.1.conv1.weight :  torch.Size([64, 64, 3, 3])
	 layer1.1.bn1.weight :  torch.Size([64])
	 layer1.1.bn1.bias :  torch.Size([64])
	 layer1.1.conv2.weight :  torch.Size([64, 64, 3, 3])
	 layer1.1.bn2.weight :  torch.Size([64])
	 layer1.1.bn2.bias :  torch.Size([64])
	 layer2.0.conv1.weight :  torch.Size([128, 64, 3, 3])
	 layer2.0.bn1.weight :  torch.Size([128])
	 layer2.0.bn1.bias :  torch.Size([128])
	 layer2.0.conv2.weight :  torch.Size([128, 128, 3, 3])
	 layer2.0.bn2.weight :  torch.Size([128])
	 layer2.0.bn2.bias :  torch.Size([128])
	 layer2.0.downsample.0.weight :  torch.Size([12

In [12]:
def train_epoch(model, train_dataloader, loss_crt, optimizer, device):
    """
    model: Model object
    train_dataloader: DataLoader over the training dataset
    loss_crt: loss function object
    optimizer: Optimizer object
    device: torch.device('cpu) or torch.device('cuda')

    The function returns:
     - the epoch training loss, which is an average over the individual batch
       losses
    """
    model.train()
    epoch_loss = 0.0
    epoch_accuracy = 0.0
    num_batches = len(train_dataloader)
    for batch_idx, batch in tqdm(enumerate(train_dataloader)):
        # shape: batch_size x 1 x 28 x 28, batch_size x 1
        # print("Train")
        # print(batch)
        batch_img, batch_labels, _ = batch
        # move data to GPU
        batch_img = batch_img.to(device)
        batch_labels = batch_labels.to(device)

        # initialize as zeros all the gradients of the model
        model.zero_grad()

        # get predictions from the FORWARD pass
        # shape: batch_size x 10
        output = model(batch_img)

        loss = loss_crt(output, batch_labels.squeeze())
        loss_scalar = loss.item()

        # BACKPROPAGATE the gradients
        loss.backward()
        # use the gradients to OPTIMISE the model
        optimizer.step()

        epoch_loss += loss_scalar

        pred = output.argmax(dim=1, keepdim=True)
        epoch_accuracy += pred.eq(batch_labels.view_as(pred)).float().mean().item()

    epoch_loss = epoch_loss/num_batches
    epoch_accuracy = 100. * epoch_accuracy/num_batches
    return epoch_loss, epoch_accuracy

def eval_epoch(model, val_dataloader, loss_crt, device):
    """
    model: Model object
    val_dataloader: DataLoader over the validation dataset
    loss_crt: loss function object
    device: torch.device('cpu) or torch.device('cuda')

    The function returns:
     - the epoch validation loss, which is an average over the individual batch
       losses
    """
    model.eval()
    epoch_loss = 0.0
    epoch_accuracy = 0.0
    num_batches = len(val_dataloader)
    with torch.no_grad():
        for batch_idx, batch in tqdm(enumerate(val_dataloader)):
            # print("Eval")
            # print(batch)
            # shape: batch_size x 3 x 28 x 28, batch_size x 1
            batch_img, batch_labels, _ = batch
            current_batch_size = batch_img.size(0)

            # move data to GPU
            batch_img = batch_img.to(device)
            batch_labels = batch_labels.to(device)

            # batch_size x 10
            output = model(batch_img)

            loss = loss_crt(output, batch_labels.squeeze())
            loss_scalar = loss.item()

            epoch_loss += loss_scalar

            pred = output.argmax(dim=1, keepdim=True)
            epoch_accuracy += pred.eq(batch_labels.view_as(pred)).float().mean().item()

    epoch_loss = epoch_loss/num_batches
    epoch_accuracy = 100. * epoch_accuracy/num_batches
    return epoch_loss, epoch_accuracy

In [13]:
resnet18_pretrained.to(device)

# create a SGD optimizer
optimizer = torch.optim.SGD(resnet18_pretrained.parameters(), lr=0.01, momentum=0.9)

# set up loss function
loss_criterion = nn.CrossEntropyLoss()

# evaluate the initial model
# val_loss, al_accuracy = eval_epoch(resnet18_pretrained, id_val_loader, loss_criterion, device)
# print('Validation performance before finetuning -- loss: %10.8f, accuracy: %10.8f'%(val_loss, val_accuracy))


In [None]:

# finetune the model
train_losses = []
train_accuracies = []
id_val_losses = []
id_val_accuracies = []
for epoch in range(1, num_epochs+1):
  train_loss, train_accuracy = train_epoch(resnet18_pretrained, train_loader, loss_criterion, optimizer, device)
  val_loss, val_accuracy = eval_epoch(resnet18_pretrained, id_val_loader, loss_criterion, device)
  train_losses.append(train_loss)
  id_val_losses.append(val_loss)
  train_accuracies.append(train_accuracy)
  id_val_accuracies.append(val_accuracy)
  print('\nEpoch %d'%(epoch))
  print('train loss: %10.8f, accuracy: %10.8f'%(train_loss, train_accuracy))
  print('id_val loss: %10.8f, accuracy: %10.8f'%(val_loss, val_accuracy))

In [None]:
def run_eval_ood(model, loader, loss_criterion, device, eval_type):
  losses = []
  accuracies = []

  loss, accuracy = eval_epoch(model, loader, loss_criterion, device)
  losses.append(loss)
  accuracies.append(accuracy)
  print(eval_type + ' loss: %10.8f, accuracy: %10.8f'%(loss, accuracy))

In [None]:
run_eval_ood(resnet18_pretrained, val_loader, loss_criterion, device, "val")

360it [00:43,  8.24it/s]

val loss: 0.46796423, accuracy: 87.63715279





In [None]:
run_eval_ood(resnet18_pretrained, test_loader, loss_criterion, device, "test")

878it [01:48,  8.10it/s]

test loss: 1.12270584, accuracy: 77.62314920





In [101]:
densenet121_pretrained = models.densenet121(pretrained=True)

# for param in densenet121_pretrained.parameters():
  # param.requires_grad = False

check_model_parameters(densenet121_pretrained, display_layers=False)


Number of model parameters:  7978856
Number of trainable parameters:  7978856


In [102]:

densenet121_pretrained.classifier = nn.Linear(in_features=1024, out_features=2, bias=True)
check_model_parameters(densenet121_pretrained, display_layers=False)

Number of model parameters:  6955906
Number of trainable parameters:  6955906


In [103]:
densenet121_pretrained.to(device)

# create a SGD optimizer
optimizer = torch.optim.SGD(densenet121_pretrained.parameters(), lr=0.01, momentum=0.9)

# set up loss function
loss_criterion = nn.CrossEntropyLoss()

In [None]:

# evaluate the initial model
# val_loss, al_accuracy = eval_epoch(resnet18_pretrained, id_val_loader, loss_criterion, device)
# print('Validation performance before finetuning -- loss: %10.8f, accuracy: %10.8f'%(val_loss, val_accuracy))

# finetune the model
train_losses = []
train_accuracies = []
val_losses = []
val_accuracies = []
id_val_losses = []
id_val_accuracies = []

for epoch in range(1, num_epochs+1):
  train_loss, train_accuracy = train_epoch(densenet121_pretrained, train_loader, loss_criterion, optimizer, device)
  val_loss, val_accuracy = eval_epoch(densenet121_pretrained, id_val_loader, loss_criterion, device)
  train_losses.append(train_loss)
  id_val_losses.append(val_loss)
  train_accuracies.append(train_accuracy)
  id_val_accuracies.append(val_accuracy)
  print('\nEpoch %d'%(epoch))
  print('train loss: %10.8f, accuracy: %10.8f'%(train_loss, train_accuracy))
  print('id_val loss: %10.8f, accuracy: %10.8f'%(val_loss, val_accuracy))


In [None]:
run_eval_ood(densenet121_pretrained, val_loader, loss_criterion, device, "val")

360it [00:49,  7.26it/s]

val loss: 0.61660476, accuracy: 85.78125000





In [None]:
run_eval_ood(densenet121_pretrained, test_loader, loss_criterion, device, "test")

878it [02:00,  7.29it/s]

test loss: 0.91136280, accuracy: 79.53445330





In [None]:
torch.save(resnet18_pretrained.state_dict(), "resnet18_pretrained.pt")

In [None]:
torch.save(densenet121_pretrained.state_dict(), "densenet121_pretrained.pt")

In [99]:
resnetxt50_pretrained = models.resnext50_32x4d(pretrained=True)

# freeze all the layers
# for param in resnetxt50_pretrained.parameters():
  # param.requires_grad = False

resnetxt50_pretrained.fc = nn.Linear(in_features=2048, out_features=2, bias=True)

resnetxt50_pretrained.to(device)

# create a SGD optimizer
optimizer = torch.optim.SGD(resnetxt50_pretrained.parameters(), lr=0.01, momentum=0.9)

# set up loss function
loss_criterion = nn.CrossEntropyLoss()


In [100]:
# finetune the model
train_losses = []
train_accuracies = []
id_val_losses = []
id_val_accuracies = []
for epoch in range(1, num_epochs+1):
  train_loss, train_accuracy = train_epoch(resnetxt50_pretrained, train_loader, loss_criterion, optimizer, device)
  val_loss, val_accuracy = eval_epoch(resnetxt50_pretrained, id_val_loader, loss_criterion, device)
  train_losses.append(train_loss)
  id_val_losses.append(val_loss)
  train_accuracies.append(train_accuracy)
  id_val_accuracies.append(val_accuracy)
  print('\nEpoch %d'%(epoch))
  print('train loss: %10.8f, accuracy: %10.8f'%(train_loss, train_accuracy))
  print('id_val loss: %10.8f, accuracy: %10.8f'%(val_loss, val_accuracy))


435it [05:08,  1.41it/s]


KeyboardInterrupt: 

In [None]:
run_eval_ood(resnetxt50_pretrained, val_loader, loss_criterion, device, "val")

In [None]:
run_eval_ood(resnetxt50_pretrained, test_loader, loss_criterion, device, "test")

In [None]:
torch.save(resnetxt50_pretrained.state_dict(), "resnetxt50_pretrained.pt")

In [14]:
resnet18_pretrained.load_state_dict(torch.load("resnet18_pretrained_all_grad.pt"))

<All keys matched successfully>

In [53]:
"""
Rupere cap de clasificare si calculare distanta l2 si cos medie intre ood si iid
"""

# putem sa folosim labels atunci cand facem distanta? (i.e. distanta intra-clasa sau inter-clasa)

def calc_dist_iid_all(model, dataloader, metric="l2"):
    modules = list(model.children())[:-1]
    # print(modules)
    model = nn.Sequential(*modules)
    average_dist_iid = [[], []]
    model.eval()

    with torch.no_grad():
        for batch_idx, batch in tqdm(enumerate(dataloader)):
            batch_img, batch_labels, _ = batch

            batch_img = batch_img.to(device)
            batch_labels = batch_labels.to(device)

            output = model(batch_img)
            output = output.reshape(batch_img.shape[0], -1)
             #if batch_idx == 0:
                # print(output.shape)

            for idx_img_1 in range(batch_img.shape[0] - 1):
                for idx_img_2 in range(idx_img_1 + 1, batch_img.shape[0]):

                    if batch_labels[idx_img_1] == batch_labels[idx_img_2]:
                        if metric == "l2":
                            dist = torch.nn.functional.pairwise_distance(output[idx_img_1], output[idx_img_2])
                        elif metric == "cos":
                            cos = nn.CosineSimilarity(dim=1, eps=1e-6)
                            dist = cos(output[idx_img_1].reshape(1, -1), output[idx_img_2].reshape(1, -1))
                        elif metric == "mink":
                            dist = torch.cdist(output[idx_img_1].reshape(1, -1), output[idx_img_2].reshape(1, -1), p=3)
                        # dist = (output[idx_img_1] - output[idx_img_2]).pow(2).sum(0).sqrt()
                        average_dist_iid[batch_labels[idx_img_1]].append(dist.item())

    return average_dist_iid

def calc_dist_all(model, dataloader, x, y, metric="l2"): # x-> image to be decided if iid or ood, y-> label
    modules = list(model.children())[:-1]
    model = nn.Sequential(*modules)
    average_dist = 0
    count_labels = 0
    model.eval()
    num_batches = len(dataloader)
    # x, y, _ = img
    print(y)
    x = x.to(device)
    y = y.to(device)
    with torch.no_grad():
        x_shape = list(x.shape)
        x = x.reshape(1, x_shape[0], x_shape[1], x_shape[2])

        output_x = model(x)
        output_x = output_x.reshape(1, -1)
        for batch_idx, batch in tqdm(enumerate(dataloader)):
            batch_img, batch_labels, _ = batch

            batch_img = batch_img.to(device)
            batch_labels = batch_labels.to(device)

            output = model(batch_img)
            output = output.reshape(batch_img.shape[0], -1)
            # if batch_idx == 0:
                # print(output.shape)

            for idx_img_1 in range(batch_img.shape[0]):

                if batch_labels[idx_img_1] == y:

                    if metric == "l2":
                        dist = torch.nn.functional.pairwise_distance(output[idx_img_1], output_x)
                    elif metric == "cos":
                        cos = nn.CosineSimilarity(dim=1, eps=1e-6)
                        dist = cos(output[idx_img_1].reshape(1, -1), output_x.reshape(1, -1))
                    elif metric == "mink":
                        dist = torch.cdist(output[idx_img_1].reshape(1, -1), output_x.reshape(1, -1), p=3)
                    average_dist += dist.item()
                    count_labels += 1

    return average_dist / count_labels


In [15]:
import numpy as np
def calc_mean_std(arr):
    arr = np.asarray(arr)
    return arr.mean(), arr.std()


In [41]:
dist_iid = calc_dist_iid_all(resnet18_pretrained, id_val_loader)
print(calc_mean_std(dist_iid[0]))
print(calc_mean_std(dist_iid[1]))

347it [00:36,  9.55it/s]

(7.966576247082313, 4.114453328341953)
(49.825562248387, 43.42303952766147)





In [42]:
dist_iid_cos = calc_dist_iid_all(resnet18_pretrained, id_val_loader, "cos")
print(calc_mean_std(dist_iid_cos[0]))
print(calc_mean_std(dist_iid_cos[1]))

347it [00:51,  6.73it/s]

(0.7919799049652289, 0.1680152192808706)
(0.7590619617230417, 0.17710607193633654)





In [29]:
dist_l2_1_ood=[]
batch_data, batch_labels, _ = next(iter(val_loader))
for idx in range(len(batch_labels)):
    dist_l2_1_ood.append(calc_dist_all(resnet18_pretrained, id_val_loader, batch_data[idx], batch_labels[idx]))
print(dist_l2_1_ood)

tensor(1)


347it [02:26,  2.36it/s]


tensor(1)


347it [00:19, 17.44it/s]


tensor(1)


347it [00:20, 16.61it/s]


tensor(1)


347it [00:19, 17.43it/s]


tensor(1)


347it [00:19, 17.98it/s]


tensor(1)


347it [00:19, 17.84it/s]


tensor(1)


347it [00:19, 17.63it/s]


tensor(1)


347it [00:19, 17.89it/s]


tensor(1)


347it [00:19, 17.67it/s]


tensor(1)


347it [00:20, 17.10it/s]


tensor(1)


347it [00:21, 16.32it/s]


tensor(1)


347it [00:19, 17.70it/s]


tensor(1)


347it [00:19, 17.75it/s]


tensor(1)


347it [00:20, 16.93it/s]


tensor(1)


347it [00:20, 16.93it/s]


tensor(1)


347it [00:20, 17.06it/s]


tensor(1)


347it [00:20, 16.60it/s]


tensor(1)


347it [00:20, 16.82it/s]


tensor(1)


347it [00:20, 17.10it/s]


tensor(1)


347it [00:20, 16.81it/s]


tensor(1)


347it [00:20, 17.22it/s]


tensor(1)


347it [00:20, 16.87it/s]


tensor(1)


347it [00:21, 16.45it/s]


tensor(1)


347it [00:20, 16.95it/s]


tensor(1)


347it [00:20, 17.04it/s]


tensor(1)


347it [00:20, 16.59it/s]


tensor(1)


347it [00:20, 16.94it/s]


tensor(1)


347it [00:20, 16.85it/s]


tensor(1)


347it [00:20, 16.82it/s]


tensor(1)


347it [00:21, 16.51it/s]


tensor(1)


347it [00:20, 17.24it/s]


tensor(1)


347it [00:20, 17.09it/s]

[48.12580185434881, 47.122555061491404, 48.19732158671518, 48.86843590476756, 48.42316879562676, 47.934727400982375, 48.40140213387857, 48.42589491071027, 48.27477378906638, 48.52999814776862, 48.29174712005708, 47.49868544503276, 48.155919726056545, 47.84677312140183, 47.84191508635555, 47.56111748240753, 48.35410220488408, 48.08161313120373, 48.19475751959513, 48.26106034572395, 48.26563705590794, 47.97635042360272, 48.096771982324725, 48.392889665507454, 48.68382668782692, 48.70096763753377, 48.47466665370229, 48.63761319091284, 46.8111308097578, 48.53838990197721, 48.394299754054416, 48.298474198522065]





NameError: name 'calc_mean_std' is not defined

In [33]:
print(calc_mean_std(dist_l2_1_ood))


(48.176962147803266, 0.435619998337812)


In [34]:

for batch_idx, batch in tqdm(enumerate(val_loader)):
    if batch_idx > 200:
        batch_data, batch_labels, _ = batch
        break


201it [01:41,  1.99it/s]


In [35]:
dist_l2_0_ood=[]
for idx in range(len(batch_labels)):
    dist_l2_0_ood.append(calc_dist_all(resnet18_pretrained, id_val_loader, batch_data[idx], batch_labels[idx]))
print(dist_l2_0_ood)
print(calc_mean_std(dist_l2_0_ood))

tensor(0)


347it [01:38,  3.53it/s]


tensor(0)


347it [00:19, 17.87it/s]


tensor(0)


347it [00:19, 17.80it/s]


tensor(0)


347it [00:19, 17.83it/s]


tensor(0)


347it [00:19, 17.97it/s]


tensor(0)


347it [00:19, 17.39it/s]


tensor(0)


347it [00:19, 17.76it/s]


tensor(0)


347it [00:19, 17.59it/s]


tensor(0)


347it [00:20, 16.90it/s]


tensor(0)


347it [00:20, 16.88it/s]


tensor(0)


347it [00:20, 16.87it/s]


tensor(0)


347it [00:21, 16.44it/s]


tensor(0)


347it [00:21, 15.84it/s]


tensor(0)


347it [00:19, 17.57it/s]


tensor(0)


347it [00:20, 16.95it/s]


tensor(0)


347it [00:22, 15.48it/s]


tensor(0)


347it [00:20, 16.62it/s]


tensor(0)


347it [00:21, 16.25it/s]


tensor(0)


347it [00:20, 16.82it/s]


tensor(0)


347it [00:20, 16.89it/s]


tensor(0)


347it [00:19, 17.88it/s]


tensor(0)


347it [00:20, 16.91it/s]


tensor(0)


347it [00:19, 17.54it/s]


tensor(0)


347it [00:19, 17.45it/s]


tensor(0)


347it [00:19, 17.93it/s]


tensor(0)


347it [00:20, 16.64it/s]


tensor(0)


347it [00:20, 16.55it/s]


tensor(0)


347it [00:20, 17.10it/s]


tensor(0)


347it [00:20, 17.17it/s]


tensor(0)


347it [00:20, 17.29it/s]


tensor(0)


347it [00:20, 17.02it/s]


tensor(0)


347it [00:21, 16.45it/s]

[6.447804433305278, 7.410187580465462, 9.312605382544106, 9.999398875687643, 7.107455103621232, 6.62457378703924, 8.163959386986267, 9.453281946286097, 10.92258991177105, 6.886451105716525, 6.65852417166171, 7.121108471517348, 7.596846452209618, 10.719539246147166, 8.061968111302418, 7.656580279540352, 7.612348923062648, 6.898071432828647, 10.099604326586602, 6.972398030911118, 7.802922748255329, 6.866422055513252, 6.751424783661893, 7.658601330204207, 9.276851113886972, 8.483976501528598, 8.22611745022655, 8.345342162569604, 10.057817286399805, 11.670284460715335, 8.98549822165429, 8.240271701318202]
(8.252838336722643, 1.3907146492780709)





In [36]:
dist_cos_1_ood=[]
batch_data, batch_labels, _ = next(iter(val_loader))
for idx in range(len(batch_labels)):
    dist_cos_1_ood.append(calc_dist_all(resnet18_pretrained, id_val_loader, batch_data[idx], batch_labels[idx], "cos"))

print(dist_cos_1_ood)
print(calc_mean_std(dist_cos_1_ood))

tensor(1)


347it [00:21, 16.17it/s]


tensor(1)


347it [00:21, 16.16it/s]


tensor(1)


347it [00:21, 16.44it/s]


tensor(1)


347it [00:21, 16.43it/s]


tensor(1)


347it [00:21, 16.44it/s]


tensor(1)


347it [00:21, 16.47it/s]


tensor(1)


347it [00:20, 17.21it/s]


tensor(1)


347it [00:20, 17.17it/s]


tensor(1)


347it [00:20, 17.05it/s]


tensor(1)


347it [00:21, 15.93it/s]


tensor(1)


347it [00:21, 16.12it/s]


tensor(1)


347it [00:22, 15.72it/s]


tensor(1)


347it [00:20, 16.95it/s]


tensor(1)


347it [00:20, 17.17it/s]


tensor(1)


347it [00:20, 17.05it/s]


tensor(1)


347it [00:20, 17.31it/s]


tensor(1)


347it [00:21, 16.43it/s]


tensor(1)


347it [00:20, 16.98it/s]


tensor(1)


347it [00:21, 16.39it/s]


tensor(1)


347it [00:21, 16.49it/s]


tensor(1)


347it [00:21, 16.40it/s]


tensor(1)


347it [00:21, 16.38it/s]


tensor(1)


347it [00:20, 16.56it/s]


tensor(1)


347it [00:21, 15.90it/s]


tensor(1)


347it [00:23, 14.82it/s]


tensor(1)


347it [00:20, 16.62it/s]


tensor(1)


347it [00:21, 15.86it/s]


tensor(1)


347it [00:21, 15.98it/s]


tensor(1)


347it [00:21, 15.82it/s]


tensor(1)


347it [00:21, 15.88it/s]


tensor(1)


347it [00:22, 15.47it/s]


tensor(1)


347it [00:21, 16.02it/s]

[0.175232209817253, 0.4630427185848298, 0.170388258306152, 0.10917338993182181, 0.12266235429505674, 0.22345834297703016, 0.12188075855748703, 0.11155844406321272, 0.12773705047000353, 0.12533079190072502, 0.13358671296433516, 0.3935696503296209, 0.1600833570677566, 0.2524557155848813, 0.2446404269971491, 0.32720591324721354, 0.14727172043930725, 0.1804494578660177, 0.16489989133869928, 0.14942211344689757, 0.12534383641582791, 0.21077857022635704, 0.1809078685274981, 0.12980015278059942, 0.11318587095652129, 0.12595113761969443, 0.10525240928306291, 0.1376744835699961, 0.5055735935551493, 0.11792662925456687, 0.14499225003398764, 0.1414911873307145]
(0.18571647711685704, 0.09975466088155133)





In [37]:
for batch_idx, batch in tqdm(enumerate(val_loader)):
    if batch_idx > 200:
        batch_data, batch_labels, _ = batch
        break



201it [00:09, 22.10it/s]


In [38]:
dist_cos_0_ood=[]
for idx in range(len(batch_labels)):
    dist_cos_0_ood.append(calc_dist_all(resnet18_pretrained, id_val_loader, batch_data[idx], batch_labels[idx], "cos"))
print(dist_cos_0_ood)
print(calc_mean_std(dist_cos_0_ood))

tensor(0)


347it [00:21, 16.37it/s]


tensor(0)


347it [00:20, 16.68it/s]


tensor(0)


347it [00:20, 17.08it/s]


tensor(0)


347it [00:21, 16.35it/s]


tensor(0)


347it [00:21, 15.88it/s]


tensor(0)


347it [00:20, 16.58it/s]


tensor(0)


347it [00:22, 15.52it/s]


tensor(0)


347it [00:22, 15.75it/s]


tensor(0)


347it [00:21, 16.24it/s]


tensor(0)


347it [00:20, 16.94it/s]


tensor(0)


347it [00:20, 16.71it/s]


tensor(0)


347it [00:20, 16.54it/s]


tensor(0)


347it [00:21, 16.23it/s]


tensor(0)


347it [00:21, 15.82it/s]


tensor(0)


347it [00:20, 16.62it/s]


tensor(0)


347it [00:21, 15.98it/s]


tensor(0)


347it [00:20, 17.11it/s]


tensor(0)


347it [00:20, 17.04it/s]


tensor(0)


347it [00:21, 16.30it/s]


tensor(0)


347it [00:19, 17.70it/s]


tensor(0)


347it [00:21, 16.04it/s]


tensor(0)


347it [00:21, 16.19it/s]


tensor(0)


347it [00:20, 17.05it/s]


tensor(0)


347it [00:20, 16.85it/s]


tensor(0)


347it [00:21, 16.18it/s]


tensor(0)


347it [00:20, 16.56it/s]


tensor(0)


347it [00:20, 16.83it/s]


tensor(0)


347it [00:22, 15.55it/s]


tensor(0)


347it [00:20, 16.97it/s]


tensor(0)


347it [00:21, 16.41it/s]


tensor(0)


347it [00:22, 15.70it/s]


tensor(0)


347it [00:21, 15.86it/s]

[0.8398456293836027, 0.7464824036629562, 0.40257430049999404, 0.1786766498376539, 0.78359088537249, 0.8219335027535635, 0.8216371987043604, 0.36232815443091587, 0.27823858871997487, 0.8073127665302607, 0.8205473658002671, 0.7756444591836341, 0.7582207652783405, 0.8251768045367371, 0.7733613531406462, 0.7510258162328332, 0.7300574589508723, 0.7989621979373307, 0.18834901913755794, 0.8041467814542013, 0.7413368845991564, 0.8020367983521502, 0.8432519218317668, 0.8200783369959136, 0.4127714201589858, 0.6728487402321728, 0.6402545589860562, 0.6403657719226915, 0.263097822405491, 0.7834267529642515, 0.4978268022985809, 0.6911456807696396]
(0.6586422997832828, 0.2072243901638426)





In [54]:
dist_iid_mink = calc_dist_iid_all(resnet18_pretrained, id_val_loader, "mink")
print(calc_mean_std(dist_iid_mink[0]))
print(calc_mean_std(dist_iid_mink[1]))

"""
(7.966576247082313, 4.114453328341953)
(49.825562248387, 43.42303952766147)
"""


347it [00:33, 10.28it/s]

(4.242545318183882, 2.0646314975394597)
(19.471529235606923, 15.98058256400135)





'\n(7.966576247082313, 4.114453328341953)\n(49.825562248387, 43.42303952766147)\n'

In [55]:
"""Distanta minkowski"""
dist_mink_1_ood=[]
batch_data, batch_labels, _ = next(iter(val_loader))
for idx in range(len(batch_labels)):
    dist_mink_1_ood.append(calc_dist_all(resnet18_pretrained, id_val_loader, batch_data[idx], batch_labels[idx], "mink"))

print(dist_mink_1_ood)
print(calc_mean_std(dist_mink_1_ood))


tensor(1)


347it [00:20, 16.75it/s]


tensor(1)


347it [00:19, 17.55it/s]


tensor(1)


347it [00:21, 16.31it/s]


tensor(1)


347it [00:21, 16.27it/s]


tensor(1)


347it [00:21, 16.09it/s]


tensor(1)


347it [00:20, 16.94it/s]


tensor(1)


347it [00:20, 17.24it/s]


tensor(1)


347it [00:20, 17.25it/s]


tensor(1)


347it [00:22, 15.46it/s]


tensor(1)


347it [00:26, 13.27it/s]


tensor(1)


347it [00:19, 17.54it/s]


tensor(1)


347it [00:25, 13.80it/s]


tensor(1)


347it [00:34, 10.07it/s]


tensor(1)


347it [01:13,  4.70it/s]


tensor(1)


347it [00:21, 16.27it/s]


tensor(1)


347it [00:19, 17.54it/s]


tensor(1)


347it [00:19, 17.57it/s]


tensor(1)


347it [00:19, 17.44it/s]


tensor(1)


347it [00:20, 16.73it/s]


tensor(1)


347it [00:19, 17.41it/s]


tensor(1)


347it [00:20, 16.74it/s]


tensor(1)


347it [00:20, 16.81it/s]


tensor(1)


347it [00:20, 16.65it/s]


tensor(1)


347it [00:20, 16.73it/s]


tensor(1)


347it [00:20, 17.06it/s]


tensor(1)


347it [00:20, 16.94it/s]


tensor(1)


347it [00:20, 16.88it/s]


tensor(1)


347it [00:20, 17.08it/s]


tensor(1)


347it [00:20, 16.63it/s]


tensor(1)


347it [00:20, 16.69it/s]


tensor(1)


347it [00:20, 16.72it/s]


tensor(1)


347it [00:20, 16.69it/s]

[18.806861381425776, 18.428017747809502, 18.84163966417182, 19.13327356422054, 18.931674773571544, 18.72822936117464, 18.924689274611385, 18.92445374192953, 18.864682724112985, 18.975605630730875, 18.870566339177405, 18.56826925163358, 18.818919786833202, 18.70165938285986, 18.697113210815584, 18.586759681133604, 18.904315509091298, 18.789364965038676, 18.834630281620697, 18.86168492073423, 18.860953370570353, 18.750304582305002, 18.795131448590297, 18.919097869308647, 19.04862352648623, 19.062619922998813, 18.942149502691308, 19.042837906859592, 18.314613651035692, 18.980897942213172, 18.91981276443404, 18.881195099049286]
(18.834707774351223, 0.17335022120509108)





In [56]:
for batch_idx, batch in tqdm(enumerate(val_loader)):
    if batch_idx > 200:
        batch_data, batch_labels, _ = batch
        break

201it [00:13, 15.29it/s]


In [57]:
dist_mink_0_ood=[]
for idx in range(len(batch_labels)):
    dist_mink_0_ood.append(calc_dist_all(resnet18_pretrained, id_val_loader, batch_data[idx], batch_labels[idx], "mink"))
print(dist_mink_0_ood)
print(calc_mean_std(dist_mink_0_ood))

tensor(0)


347it [00:19, 17.46it/s]


tensor(0)


347it [00:20, 16.99it/s]


tensor(0)


347it [00:20, 16.71it/s]


tensor(0)


347it [00:22, 15.65it/s]


tensor(0)


347it [00:19, 17.60it/s]


tensor(0)


347it [00:19, 17.72it/s]


tensor(0)


347it [00:19, 17.79it/s]


tensor(0)


347it [00:19, 17.42it/s]


tensor(0)


347it [00:20, 17.28it/s]


tensor(0)


347it [00:20, 17.31it/s]


tensor(0)


347it [00:20, 17.32it/s]


tensor(0)


347it [00:20, 17.25it/s]


tensor(0)


347it [00:20, 17.18it/s]


tensor(0)


347it [00:19, 17.55it/s]


tensor(0)


347it [00:20, 16.87it/s]


tensor(0)


347it [00:20, 17.04it/s]


tensor(0)


347it [00:20, 16.83it/s]


tensor(0)


347it [00:20, 16.66it/s]


tensor(0)


347it [00:20, 16.89it/s]


tensor(0)


347it [00:20, 16.61it/s]


tensor(0)


347it [00:21, 15.81it/s]


tensor(0)


347it [00:21, 16.44it/s]


tensor(0)


347it [00:20, 17.03it/s]


tensor(0)


347it [00:20, 16.70it/s]


tensor(0)


347it [00:20, 16.56it/s]


tensor(0)


347it [00:21, 16.40it/s]


tensor(0)


347it [00:21, 16.49it/s]


tensor(0)


347it [00:20, 16.56it/s]


tensor(0)


347it [00:22, 15.26it/s]


tensor(0)


347it [00:21, 16.38it/s]


tensor(0)


347it [00:20, 16.65it/s]


tensor(0)


347it [00:20, 16.63it/s]

[3.4879563477093813, 3.9105290172994667, 4.952353565023015, 5.347901423321754, 3.8520147737457426, 3.6143249717515435, 4.319662275845475, 5.033975590386165, 5.222534278085171, 3.6938475488190647, 3.5639323552180002, 3.878303143508687, 4.128077111906854, 5.6040494969086065, 4.36397144165434, 4.172438359907464, 4.013442903001663, 3.7287927980573636, 5.346719126984802, 3.7637027302136468, 4.20871607063924, 3.765812632485144, 3.662454636745562, 4.0730222112065935, 4.972207514846465, 4.608387525147176, 4.348424479455788, 4.445709322093427, 5.1629670382354655, 6.187736644867445, 4.772453488845223, 4.4486147095607205]
(4.395469860421139, 0.6756717172554383)





In [86]:
def get_class_mean_embeddings(model, train_loader, nr_layers=-1):
    if nr_layers == -1:
        mean_embeds = np.zeros(shape=(2, 512)) # 2 classes, 512 embedding dim
    elif nr_layers == -2:
        mean_embeds = np.zeros(shape=(2, 512)) # 2 classes, 512 embedding dim no averagepool
    elif nr_layers == -3:
        mean_embeds = np.zeros(shape=(2, 256)) # 2 classes, 256 embedding dim
    count_per_class = np.zeros(shape=(2, 1))
    modules = list(model.children())[:nr_layers]
    print(list(model.children())[len(list(model.children()))-2])
    model = nn.Sequential(*modules, list(model.children())[len(list(model.children()))-2])
    print(model)
    model.eval()
    with torch.no_grad():
        for data in tqdm(train_loader):
            inputs, labels, _ = data
            inputs = inputs.to(device)
            labels = labels.numpy()

            # Extract features
            features = torch.flatten(model(inputs), start_dim=1)
            features = features.cpu().numpy()

            for idx in range(len(labels)):
                mean_embeds[labels[idx]] += features[idx]
                count_per_class[labels[idx]] += 1

    mean_embeds /= count_per_class

    return mean_embeds


In [29]:

mean_embeds = get_class_mean_embeddings(resnet18_pretrained, train_loader)
print(mean_embeds)

100%|██████████| 3119/3119 [26:45<00:00,  1.94it/s]

[[0.04605115 0.31298806 0.06170732 ... 0.54788485 0.00395388 0.08429653]
 [1.59088348 1.74999964 1.82361916 ... 0.72474282 2.19162934 1.59282019]]





In [92]:
def calc_dist_mean(model, dataloader, mean_embeds,  metric="l2", nr_layers=-1):
    modules = list(model.children())[:nr_layers]
    print(list(model.children())[len(list(model.children()))-2])
    model = nn.Sequential(*modules, list(model.children())[len(list(model.children()))-2])
    # print(model)
    average_dist = [[], []]
    model.eval()
    mean_embeds = torch.as_tensor(mean_embeds, dtype=torch.float, device=device)

    with torch.no_grad():
        for batch_idx, batch in tqdm(enumerate(dataloader)):
            batch_img, batch_labels, _ = batch

            batch_img = batch_img.to(device)
            batch_labels = batch_labels.to(device)

            output = model(batch_img)
            output = output.reshape(batch_img.shape[0], -1)
             #if batch_idx == 0:
                # print(output.shape)

            for idx_img_1 in range(batch_img.shape[0]):
                        if metric == "l2":
                            dist = torch.nn.functional.pairwise_distance(output[idx_img_1], mean_embeds[batch_labels[idx_img_1]])
                        elif metric == "cos":
                            cos = nn.CosineSimilarity(dim=1, eps=1e-6)
                            dist = cos(output[idx_img_1].reshape(1, -1), mean_embeds[batch_labels[idx_img_1]].reshape(1, -1))
                        elif metric == "mink":
                            dist = torch.cdist(output[idx_img_1].reshape(1, -1), mean_embeds[batch_labels[idx_img_1]].reshape(1, -1), p=3)
                        # dist = (output[idx_img_1] - output[idx_img_2]).pow(2).sum(0).sqrt()
                        average_dist[batch_labels[idx_img_1]].append(dist.item())

    return average_dist

In [30]:
dist_iid = calc_dist_mean(resnet18_pretrained, id_val_loader, mean_embeds)
print(calc_mean_std(dist_iid[0]))
print(calc_mean_std(dist_iid[1]))

360it [03:44,  1.61it/s]

(6.241036585918539, 2.3262729660942894)
(44.97533254815873, 35.394171235695964)





In [31]:
dist_ood = calc_dist_mean(resnet18_pretrained, val_loader, mean_embeds)
print(calc_mean_std(dist_ood[0]))
print(calc_mean_std(dist_ood[1]))

347it [04:11,  1.38it/s]

(6.09145959285451, 2.9619329014927733)
(42.740043913480655, 31.172719493451666)





In [32]:
dist_iid = calc_dist_mean(resnet18_pretrained, id_val_loader, mean_embeds, "cos")
print(calc_mean_std(dist_iid[0]))
print(calc_mean_std(dist_iid[1]))

360it [00:23, 15.60it/s]

(0.7371898718187957, 0.24255272251720775)
(0.7312124280774902, 0.19052081554650735)





In [33]:
dist_ood = calc_dist_mean(resnet18_pretrained, val_loader, mean_embeds, "cos")
print(calc_mean_std(dist_ood[0]))
print(calc_mean_std(dist_ood[1]))

347it [00:21, 16.29it/s]

(0.8683770586161912, 0.13710541809566884)
(0.782375466406486, 0.16523093317944965)





In [34]:
dist_iid = calc_dist_mean(resnet18_pretrained, id_val_loader, mean_embeds, "mink")
print(calc_mean_std(dist_iid[0]))
print(calc_mean_std(dist_iid[1]))

360it [00:21, 16.85it/s]

(3.2581733387925595, 1.1615214784287726)
(17.116536856941988, 13.627033010675087)





In [35]:
dist_ood = calc_dist_mean(resnet18_pretrained, val_loader, mean_embeds, "mink")
print(calc_mean_std(dist_ood[0]))
print(calc_mean_std(dist_ood[1]))

347it [00:19, 17.47it/s]

(3.2286353352855444, 1.5112794953031792)
(16.456411908988102, 11.836049372249718)





In [36]:
mean_embeds_val = get_class_mean_embeddings(resnet18_pretrained, val_loader)
print(mean_embeds_val)


100%|██████████| 360/360 [00:21<00:00, 17.03it/s]

[[0.03456208 0.05713736 0.05736421 ... 0.75617628 0.01469743 0.09006843]
 [0.93516183 1.42563253 1.64460315 ... 0.99811294 1.58906645 0.65909457]]





In [37]:
mean_embeds_id_val = get_class_mean_embeddings(resnet18_pretrained, id_val_loader)
print(mean_embeds_id_val)

100%|██████████| 347/347 [00:19<00:00, 18.02it/s]

[[0.04387564 0.31455867 0.06047065 ... 0.54764414 0.00450198 0.08598702]
 [1.63332741 1.80945475 1.8598073  ... 0.72630386 2.2534007  1.65054148]]





In [39]:
mean_embeds = torch.as_tensor(mean_embeds, dtype=torch.float, device=device)
mean_embeds_val = torch.as_tensor(mean_embeds_val, dtype=torch.float, device=device)
mean_embeds_id_val = torch.as_tensor(mean_embeds_id_val, dtype=torch.float, device=device)


In [40]:
print("l2 dist between embeds mean:")
print(torch.nn.functional.pairwise_distance(mean_embeds_val[0], mean_embeds[0]))
print(torch.nn.functional.pairwise_distance(mean_embeds_val[1], mean_embeds[1]))


print(torch.nn.functional.pairwise_distance(mean_embeds_id_val[0], mean_embeds[0]))
print(torch.nn.functional.pairwise_distance(mean_embeds_id_val[1], mean_embeds[1]))

tensor(4.3806, device='cuda:0')
tensor(11.3567, device='cuda:0')
tensor(0.1037, device='cuda:0')
tensor(0.8313, device='cuda:0')


In [41]:
print("cos sim between embeds mean:")
cos = nn.CosineSimilarity(dim=1, eps=1e-6)
print(cos(mean_embeds_val[0].reshape(1, -1), mean_embeds[0].reshape(1, -1)))
print(cos(mean_embeds_val[1].reshape(1, -1), mean_embeds[1].reshape(1, -1)))

print(cos(mean_embeds_id_val[0].reshape(1, -1), mean_embeds[0].reshape(1, -1)))
print(cos(mean_embeds_id_val[1].reshape(1, -1), mean_embeds[1].reshape(1, -1)))


tensor([0.9498], device='cuda:0')
tensor([0.9679], device='cuda:0')
tensor([0.9999], device='cuda:0')
tensor([1.0000], device='cuda:0')


In [42]:
print("minkowski dist between embeds mean:")
print(torch.cdist(mean_embeds_val[0].reshape(1, -1), mean_embeds[0].reshape(1, -1), p=3))
print(torch.cdist(mean_embeds_val[1].reshape(1, -1), mean_embeds[1].reshape(1, -1), p=3))

print(torch.cdist(mean_embeds_id_val[0].reshape(1, -1), mean_embeds[0].reshape(1, -1), p=3))
print(torch.cdist(mean_embeds_id_val[1].reshape(1, -1), mean_embeds[1].reshape(1, -1), p=3))

tensor([[2.2981]], device='cuda:0')
tensor([[4.6342]], device='cuda:0')
tensor([[0.0534]], device='cuda:0')
tensor([[0.3144]], device='cuda:0')


In [88]:
mean_embeds = get_class_mean_embeddings(resnet18_pretrained, train_loader, nr_layers=-3)
print(mean_embeds)

AdaptiveAvgPool2d(output_size=(1, 1))
Sequential(
  (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (2): ReLU(inplace=True)
  (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (4): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu):

100%|██████████| 3119/3119 [37:35<00:00,  1.38it/s]

[[9.62999058e-02 4.25739106e-02 2.86310611e-01 5.10014238e-02
  1.16603529e-01 1.06310124e-01 1.11294071e-01 1.08416164e-01
  6.86649944e-02 1.57307637e-02 4.01023564e-02 5.04114060e-02
  1.23886728e-01 1.12970853e-01 2.97541144e-02 3.18589409e-01
  1.34469019e-02 8.99669534e-02 1.23463651e-01 7.68679635e-02
  1.70015454e-01 2.14265131e-01 2.56818615e-01 1.31065622e-02
  2.53554544e-02 5.90108821e-02 1.43084266e-01 1.66752350e-01
  2.42696018e-01 9.17266546e-02 5.38571934e-02 7.14162661e-02
  4.70630969e-02 2.12123266e-01 1.22627422e-02 1.94444225e-02
  8.29160436e-02 3.37345507e-01 1.54171499e-01 4.54884037e-02
  3.33842238e-02 7.14162418e-02 2.98005412e-01 7.50812515e-03
  1.17932850e-01 3.55113080e-01 2.46073830e-01 5.20211946e-02
  4.08874140e-02 6.88235713e-02 9.48367416e-02 8.35003911e-02
  6.66227730e-01 4.09072305e-01 2.05891193e-01 8.61146661e-02
  1.42799703e-01 1.00277452e-01 1.42341525e-01 1.66415261e-01
  1.25623591e-01 1.05515646e-01 1.38512378e-03 5.95332861e-02
  1.2190




In [89]:
dist_iid = calc_dist_mean(resnet18_pretrained, id_val_loader, mean_embeds, nr_layers=-3)
print(calc_mean_std(dist_iid[0]))
print(calc_mean_std(dist_iid[1]))
dist_ood = calc_dist_mean(resnet18_pretrained, val_loader, mean_embeds, nr_layers=-3)
print(calc_mean_std(dist_ood[0]))
print(calc_mean_std(dist_ood[1]))

AdaptiveAvgPool2d(output_size=(1, 1))
Sequential(
  (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (2): ReLU(inplace=True)
  (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (4): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu):

360it [03:30,  1.71it/s]


(1.4623581341585625, 0.3112988782412788)
(1.9183645550376769, 0.42766292546987833)
AdaptiveAvgPool2d(output_size=(1, 1))
Sequential(
  (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (2): ReLU(inplace=True)
  (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (4): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2

347it [04:14,  1.36it/s]

(1.4383305217823106, 0.3185095421168888)
(1.5216503871895912, 0.3587628149208194)





In [90]:
dist_iid = calc_dist_mean(resnet18_pretrained, id_val_loader, mean_embeds, "cos", nr_layers=-3)
print(calc_mean_std(dist_iid[0]))
print(calc_mean_std(dist_iid[1]))
dist_ood = calc_dist_mean(resnet18_pretrained, val_loader, mean_embeds, "cos", nr_layers=-3)
print(calc_mean_std(dist_ood[0]))
print(calc_mean_std(dist_ood[1]))

AdaptiveAvgPool2d(output_size=(1, 1))
Sequential(
  (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (2): ReLU(inplace=True)
  (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (4): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu):

360it [00:18, 19.65it/s]


(0.8307092270363555, 0.06036752335644357)
(0.8285226818010802, 0.07388714336314596)
AdaptiveAvgPool2d(output_size=(1, 1))
Sequential(
  (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (2): ReLU(inplace=True)
  (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (4): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm

347it [00:17, 19.72it/s]

(0.8518162364909172, 0.04572092443268793)
(0.8804776922211843, 0.06308674302825565)





In [91]:
dist_iid = calc_dist_mean(resnet18_pretrained, id_val_loader, mean_embeds, "mink", nr_layers=-3)
print(calc_mean_std(dist_iid[0]))
print(calc_mean_std(dist_iid[1]))
dist_ood = calc_dist_mean(resnet18_pretrained, val_loader, mean_embeds, "mink", nr_layers=-3)
print(calc_mean_std(dist_ood[0]))
print(calc_mean_std(dist_ood[1]))

AdaptiveAvgPool2d(output_size=(1, 1))
Sequential(
  (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (2): ReLU(inplace=True)
  (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (4): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu):

360it [00:18, 19.65it/s]


(0.7554970607084046, 0.17429233040889258)
(0.9869966557769204, 0.24786116967803504)
AdaptiveAvgPool2d(output_size=(1, 1))
Sequential(
  (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (2): ReLU(inplace=True)
  (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (4): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm

347it [00:17, 19.74it/s]

(0.7662587817135879, 0.19185095780040134)
(0.7675117084661485, 0.1974272505649154)





In [93]:
mean_embeds_val = get_class_mean_embeddings(resnet18_pretrained, val_loader, nr_layers=-3)
print(mean_embeds_val)

AdaptiveAvgPool2d(output_size=(1, 1))
Sequential(
  (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (2): ReLU(inplace=True)
  (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (4): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu):

100%|██████████| 360/360 [00:20<00:00, 17.86it/s]

[[1.03079025e-01 9.32554994e-02 2.91884415e-01 6.32417167e-02
  2.44392870e-01 1.63989274e-01 7.45648658e-02 1.60837345e-01
  1.42723737e-01 3.74095972e-02 4.34687676e-02 6.22680398e-02
  1.57945948e-01 1.21545605e-01 3.94698559e-02 2.15939655e-01
  3.02931974e-02 1.07882529e-01 1.19910823e-01 1.23762847e-01
  1.30726576e-01 2.10457428e-01 2.29852511e-01 2.72941479e-02
  5.26532178e-02 6.52813011e-02 7.80750697e-02 1.99874614e-01
  2.42577267e-01 1.60588226e-01 5.07706562e-02 6.89630474e-02
  9.64239377e-02 1.70798941e-01 2.06209612e-02 8.23107012e-03
  3.78978362e-02 1.59796158e-01 1.79011868e-01 1.24405348e-02
  5.65703072e-02 6.17422383e-02 1.88731858e-01 3.04911523e-02
  2.11799150e-02 4.54004177e-01 2.11178783e-01 5.75053609e-02
  8.24769987e-02 7.24392854e-02 6.37014055e-02 1.34808432e-01
  5.79417419e-01 3.42582920e-01 6.42187290e-02 6.41060735e-02
  2.18138345e-01 9.90464082e-02 1.05723625e-01 1.56220013e-01
  1.68105941e-01 1.75770933e-01 2.54565412e-03 4.70536812e-02
  1.2322




In [94]:
mean_embeds_id_val = get_class_mean_embeddings(resnet18_pretrained, id_val_loader, nr_layers=-3)
print(mean_embeds_id_val)

AdaptiveAvgPool2d(output_size=(1, 1))
Sequential(
  (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (2): ReLU(inplace=True)
  (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (4): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu):

100%|██████████| 347/347 [00:16<00:00, 21.17it/s]

[[9.44398790e-02 4.17343035e-02 2.78925616e-01 5.39471499e-02
  1.15424816e-01 1.06197999e-01 1.14189213e-01 1.06198877e-01
  6.73954716e-02 1.69919183e-02 4.22502874e-02 5.01476592e-02
  1.24029062e-01 1.12372566e-01 2.92677864e-02 3.24198352e-01
  1.43453335e-02 8.88023911e-02 1.21743495e-01 7.90776454e-02
  1.70713969e-01 2.08803934e-01 2.57447182e-01 1.34650834e-02
  2.55587766e-02 5.95566567e-02 1.41423063e-01 1.69574197e-01
  2.41196247e-01 9.07171202e-02 5.36562166e-02 7.11065901e-02
  4.76083989e-02 2.15842112e-01 1.28556675e-02 1.98849042e-02
  8.08157940e-02 3.37793723e-01 1.56171699e-01 4.44945883e-02
  3.33909358e-02 7.10355442e-02 2.99450681e-01 8.01538346e-03
  1.16440136e-01 3.56537103e-01 2.43659484e-01 5.31006530e-02
  4.11474557e-02 7.07121219e-02 9.56247290e-02 8.34229393e-02
  6.69527951e-01 4.09844980e-01 2.03059547e-01 8.62793904e-02
  1.41875425e-01 9.84102231e-02 1.38465702e-01 1.66108513e-01
  1.28087475e-01 1.10359288e-01 1.38636928e-03 6.18836602e-02
  1.1995




In [95]:
mean_embeds = torch.as_tensor(mean_embeds, dtype=torch.float, device=device)
mean_embeds_val = torch.as_tensor(mean_embeds_val, dtype=torch.float, device=device)
mean_embeds_id_val = torch.as_tensor(mean_embeds_id_val, dtype=torch.float, device=device)

In [96]:
print("l2 dist between embeds mean:")
print(torch.nn.functional.pairwise_distance(mean_embeds_val[0], mean_embeds[0]))
print(torch.nn.functional.pairwise_distance(mean_embeds_val[1], mean_embeds[1]))

print(torch.nn.functional.pairwise_distance(mean_embeds_id_val[0], mean_embeds[0]))
print(torch.nn.functional.pairwise_distance(mean_embeds_id_val[1], mean_embeds[1]))
print("cos sim between embeds mean:")
cos = nn.CosineSimilarity(dim=1, eps=1e-6)
print(cos(mean_embeds_val[0].reshape(1, -1), mean_embeds[0].reshape(1, -1)))
print(cos(mean_embeds_val[1].reshape(1, -1), mean_embeds[1].reshape(1, -1)))

print(cos(mean_embeds_id_val[0].reshape(1, -1), mean_embeds[0].reshape(1, -1)))
print(cos(mean_embeds_id_val[1].reshape(1, -1), mean_embeds[1].reshape(1, -1)))

print("minkowski dist between embeds mean:")
print(torch.cdist(mean_embeds_val[0].reshape(1, -1), mean_embeds[0].reshape(1, -1), p=3))
print(torch.cdist(mean_embeds_val[1].reshape(1, -1), mean_embeds[1].reshape(1, -1), p=3))

print(torch.cdist(mean_embeds_id_val[0].reshape(1, -1), mean_embeds[0].reshape(1, -1), p=3))
print(torch.cdist(mean_embeds_id_val[1].reshape(1, -1), mean_embeds[1].reshape(1, -1), p=3))

l2 dist between embeds mean:
tensor(0.7717, device='cuda:0')
tensor(1.0831, device='cuda:0')
tensor(0.0298, device='cuda:0')
tensor(0.0180, device='cuda:0')
cos sim between embeds mean:
tensor([0.9444], device='cuda:0')
tensor([0.9329], device='cuda:0')
tensor([0.9999], device='cuda:0')
tensor([1.0000], device='cuda:0')
minkowski dist between embeds mean:
tensor([[0.3909]], device='cuda:0')
tensor([[0.5523]], device='cuda:0')
tensor([[0.0154]], device='cuda:0')
tensor([[0.0088]], device='cuda:0')


In [98]:
"""
resnext
"""
resnetxt50_pretrained.load_state_dict(torch.load("resnetxt50_pretrained_all_grad.pt"))


RuntimeError: Error(s) in loading state_dict for ResNet:
	Missing key(s) in state_dict: "conv1.weight", "bn1.weight", "bn1.bias", "bn1.running_mean", "bn1.running_var", "layer1.0.conv1.weight", "layer1.0.bn1.weight", "layer1.0.bn1.bias", "layer1.0.bn1.running_mean", "layer1.0.bn1.running_var", "layer1.0.conv2.weight", "layer1.0.bn2.weight", "layer1.0.bn2.bias", "layer1.0.bn2.running_mean", "layer1.0.bn2.running_var", "layer1.0.conv3.weight", "layer1.0.bn3.weight", "layer1.0.bn3.bias", "layer1.0.bn3.running_mean", "layer1.0.bn3.running_var", "layer1.0.downsample.0.weight", "layer1.0.downsample.1.weight", "layer1.0.downsample.1.bias", "layer1.0.downsample.1.running_mean", "layer1.0.downsample.1.running_var", "layer1.1.conv1.weight", "layer1.1.bn1.weight", "layer1.1.bn1.bias", "layer1.1.bn1.running_mean", "layer1.1.bn1.running_var", "layer1.1.conv2.weight", "layer1.1.bn2.weight", "layer1.1.bn2.bias", "layer1.1.bn2.running_mean", "layer1.1.bn2.running_var", "layer1.1.conv3.weight", "layer1.1.bn3.weight", "layer1.1.bn3.bias", "layer1.1.bn3.running_mean", "layer1.1.bn3.running_var", "layer1.2.conv1.weight", "layer1.2.bn1.weight", "layer1.2.bn1.bias", "layer1.2.bn1.running_mean", "layer1.2.bn1.running_var", "layer1.2.conv2.weight", "layer1.2.bn2.weight", "layer1.2.bn2.bias", "layer1.2.bn2.running_mean", "layer1.2.bn2.running_var", "layer1.2.conv3.weight", "layer1.2.bn3.weight", "layer1.2.bn3.bias", "layer1.2.bn3.running_mean", "layer1.2.bn3.running_var", "layer2.0.conv1.weight", "layer2.0.bn1.weight", "layer2.0.bn1.bias", "layer2.0.bn1.running_mean", "layer2.0.bn1.running_var", "layer2.0.conv2.weight", "layer2.0.bn2.weight", "layer2.0.bn2.bias", "layer2.0.bn2.running_mean", "layer2.0.bn2.running_var", "layer2.0.conv3.weight", "layer2.0.bn3.weight", "layer2.0.bn3.bias", "layer2.0.bn3.running_mean", "layer2.0.bn3.running_var", "layer2.0.downsample.0.weight", "layer2.0.downsample.1.weight", "layer2.0.downsample.1.bias", "layer2.0.downsample.1.running_mean", "layer2.0.downsample.1.running_var", "layer2.1.conv1.weight", "layer2.1.bn1.weight", "layer2.1.bn1.bias", "layer2.1.bn1.running_mean", "layer2.1.bn1.running_var", "layer2.1.conv2.weight", "layer2.1.bn2.weight", "layer2.1.bn2.bias", "layer2.1.bn2.running_mean", "layer2.1.bn2.running_var", "layer2.1.conv3.weight", "layer2.1.bn3.weight", "layer2.1.bn3.bias", "layer2.1.bn3.running_mean", "layer2.1.bn3.running_var", "layer2.2.conv1.weight", "layer2.2.bn1.weight", "layer2.2.bn1.bias", "layer2.2.bn1.running_mean", "layer2.2.bn1.running_var", "layer2.2.conv2.weight", "layer2.2.bn2.weight", "layer2.2.bn2.bias", "layer2.2.bn2.running_mean", "layer2.2.bn2.running_var", "layer2.2.conv3.weight", "layer2.2.bn3.weight", "layer2.2.bn3.bias", "layer2.2.bn3.running_mean", "layer2.2.bn3.running_var", "layer2.3.conv1.weight", "layer2.3.bn1.weight", "layer2.3.bn1.bias", "layer2.3.bn1.running_mean", "layer2.3.bn1.running_var", "layer2.3.conv2.weight", "layer2.3.bn2.weight", "layer2.3.bn2.bias", "layer2.3.bn2.running_mean", "layer2.3.bn2.running_var", "layer2.3.conv3.weight", "layer2.3.bn3.weight", "layer2.3.bn3.bias", "layer2.3.bn3.running_mean", "layer2.3.bn3.running_var", "layer3.0.conv1.weight", "layer3.0.bn1.weight", "layer3.0.bn1.bias", "layer3.0.bn1.running_mean", "layer3.0.bn1.running_var", "layer3.0.conv2.weight", "layer3.0.bn2.weight", "layer3.0.bn2.bias", "layer3.0.bn2.running_mean", "layer3.0.bn2.running_var", "layer3.0.conv3.weight", "layer3.0.bn3.weight", "layer3.0.bn3.bias", "layer3.0.bn3.running_mean", "layer3.0.bn3.running_var", "layer3.0.downsample.0.weight", "layer3.0.downsample.1.weight", "layer3.0.downsample.1.bias", "layer3.0.downsample.1.running_mean", "layer3.0.downsample.1.running_var", "layer3.1.conv1.weight", "layer3.1.bn1.weight", "layer3.1.bn1.bias", "layer3.1.bn1.running_mean", "layer3.1.bn1.running_var", "layer3.1.conv2.weight", "layer3.1.bn2.weight", "layer3.1.bn2.bias", "layer3.1.bn2.running_mean", "layer3.1.bn2.running_var", "layer3.1.conv3.weight", "layer3.1.bn3.weight", "layer3.1.bn3.bias", "layer3.1.bn3.running_mean", "layer3.1.bn3.running_var", "layer3.2.conv1.weight", "layer3.2.bn1.weight", "layer3.2.bn1.bias", "layer3.2.bn1.running_mean", "layer3.2.bn1.running_var", "layer3.2.conv2.weight", "layer3.2.bn2.weight", "layer3.2.bn2.bias", "layer3.2.bn2.running_mean", "layer3.2.bn2.running_var", "layer3.2.conv3.weight", "layer3.2.bn3.weight", "layer3.2.bn3.bias", "layer3.2.bn3.running_mean", "layer3.2.bn3.running_var", "layer3.3.conv1.weight", "layer3.3.bn1.weight", "layer3.3.bn1.bias", "layer3.3.bn1.running_mean", "layer3.3.bn1.running_var", "layer3.3.conv2.weight", "layer3.3.bn2.weight", "layer3.3.bn2.bias", "layer3.3.bn2.running_mean", "layer3.3.bn2.running_var", "layer3.3.conv3.weight", "layer3.3.bn3.weight", "layer3.3.bn3.bias", "layer3.3.bn3.running_mean", "layer3.3.bn3.running_var", "layer3.4.conv1.weight", "layer3.4.bn1.weight", "layer3.4.bn1.bias", "layer3.4.bn1.running_mean", "layer3.4.bn1.running_var", "layer3.4.conv2.weight", "layer3.4.bn2.weight", "layer3.4.bn2.bias", "layer3.4.bn2.running_mean", "layer3.4.bn2.running_var", "layer3.4.conv3.weight", "layer3.4.bn3.weight", "layer3.4.bn3.bias", "layer3.4.bn3.running_mean", "layer3.4.bn3.running_var", "layer3.5.conv1.weight", "layer3.5.bn1.weight", "layer3.5.bn1.bias", "layer3.5.bn1.running_mean", "layer3.5.bn1.running_var", "layer3.5.conv2.weight", "layer3.5.bn2.weight", "layer3.5.bn2.bias", "layer3.5.bn2.running_mean", "layer3.5.bn2.running_var", "layer3.5.conv3.weight", "layer3.5.bn3.weight", "layer3.5.bn3.bias", "layer3.5.bn3.running_mean", "layer3.5.bn3.running_var", "layer4.0.conv1.weight", "layer4.0.bn1.weight", "layer4.0.bn1.bias", "layer4.0.bn1.running_mean", "layer4.0.bn1.running_var", "layer4.0.conv2.weight", "layer4.0.bn2.weight", "layer4.0.bn2.bias", "layer4.0.bn2.running_mean", "layer4.0.bn2.running_var", "layer4.0.conv3.weight", "layer4.0.bn3.weight", "layer4.0.bn3.bias", "layer4.0.bn3.running_mean", "layer4.0.bn3.running_var", "layer4.0.downsample.0.weight", "layer4.0.downsample.1.weight", "layer4.0.downsample.1.bias", "layer4.0.downsample.1.running_mean", "layer4.0.downsample.1.running_var", "layer4.1.conv1.weight", "layer4.1.bn1.weight", "layer4.1.bn1.bias", "layer4.1.bn1.running_mean", "layer4.1.bn1.running_var", "layer4.1.conv2.weight", "layer4.1.bn2.weight", "layer4.1.bn2.bias", "layer4.1.bn2.running_mean", "layer4.1.bn2.running_var", "layer4.1.conv3.weight", "layer4.1.bn3.weight", "layer4.1.bn3.bias", "layer4.1.bn3.running_mean", "layer4.1.bn3.running_var", "layer4.2.conv1.weight", "layer4.2.bn1.weight", "layer4.2.bn1.bias", "layer4.2.bn1.running_mean", "layer4.2.bn1.running_var", "layer4.2.conv2.weight", "layer4.2.bn2.weight", "layer4.2.bn2.bias", "layer4.2.bn2.running_mean", "layer4.2.bn2.running_var", "layer4.2.conv3.weight", "layer4.2.bn3.weight", "layer4.2.bn3.bias", "layer4.2.bn3.running_mean", "layer4.2.bn3.running_var", "fc.weight", "fc.bias". 
	Unexpected key(s) in state_dict: "features.conv0.weight", "features.norm0.weight", "features.norm0.bias", "features.norm0.running_mean", "features.norm0.running_var", "features.norm0.num_batches_tracked", "features.denseblock1.denselayer1.norm1.weight", "features.denseblock1.denselayer1.norm1.bias", "features.denseblock1.denselayer1.norm1.running_mean", "features.denseblock1.denselayer1.norm1.running_var", "features.denseblock1.denselayer1.norm1.num_batches_tracked", "features.denseblock1.denselayer1.conv1.weight", "features.denseblock1.denselayer1.norm2.weight", "features.denseblock1.denselayer1.norm2.bias", "features.denseblock1.denselayer1.norm2.running_mean", "features.denseblock1.denselayer1.norm2.running_var", "features.denseblock1.denselayer1.norm2.num_batches_tracked", "features.denseblock1.denselayer1.conv2.weight", "features.denseblock1.denselayer2.norm1.weight", "features.denseblock1.denselayer2.norm1.bias", "features.denseblock1.denselayer2.norm1.running_mean", "features.denseblock1.denselayer2.norm1.running_var", "features.denseblock1.denselayer2.norm1.num_batches_tracked", "features.denseblock1.denselayer2.conv1.weight", "features.denseblock1.denselayer2.norm2.weight", "features.denseblock1.denselayer2.norm2.bias", "features.denseblock1.denselayer2.norm2.running_mean", "features.denseblock1.denselayer2.norm2.running_var", "features.denseblock1.denselayer2.norm2.num_batches_tracked", "features.denseblock1.denselayer2.conv2.weight", "features.denseblock1.denselayer3.norm1.weight", "features.denseblock1.denselayer3.norm1.bias", "features.denseblock1.denselayer3.norm1.running_mean", "features.denseblock1.denselayer3.norm1.running_var", "features.denseblock1.denselayer3.norm1.num_batches_tracked", "features.denseblock1.denselayer3.conv1.weight", "features.denseblock1.denselayer3.norm2.weight", "features.denseblock1.denselayer3.norm2.bias", "features.denseblock1.denselayer3.norm2.running_mean", "features.denseblock1.denselayer3.norm2.running_var", "features.denseblock1.denselayer3.norm2.num_batches_tracked", "features.denseblock1.denselayer3.conv2.weight", "features.denseblock1.denselayer4.norm1.weight", "features.denseblock1.denselayer4.norm1.bias", "features.denseblock1.denselayer4.norm1.running_mean", "features.denseblock1.denselayer4.norm1.running_var", "features.denseblock1.denselayer4.norm1.num_batches_tracked", "features.denseblock1.denselayer4.conv1.weight", "features.denseblock1.denselayer4.norm2.weight", "features.denseblock1.denselayer4.norm2.bias", "features.denseblock1.denselayer4.norm2.running_mean", "features.denseblock1.denselayer4.norm2.running_var", "features.denseblock1.denselayer4.norm2.num_batches_tracked", "features.denseblock1.denselayer4.conv2.weight", "features.denseblock1.denselayer5.norm1.weight", "features.denseblock1.denselayer5.norm1.bias", "features.denseblock1.denselayer5.norm1.running_mean", "features.denseblock1.denselayer5.norm1.running_var", "features.denseblock1.denselayer5.norm1.num_batches_tracked", "features.denseblock1.denselayer5.conv1.weight", "features.denseblock1.denselayer5.norm2.weight", "features.denseblock1.denselayer5.norm2.bias", "features.denseblock1.denselayer5.norm2.running_mean", "features.denseblock1.denselayer5.norm2.running_var", "features.denseblock1.denselayer5.norm2.num_batches_tracked", "features.denseblock1.denselayer5.conv2.weight", "features.denseblock1.denselayer6.norm1.weight", "features.denseblock1.denselayer6.norm1.bias", "features.denseblock1.denselayer6.norm1.running_mean", "features.denseblock1.denselayer6.norm1.running_var", "features.denseblock1.denselayer6.norm1.num_batches_tracked", "features.denseblock1.denselayer6.conv1.weight", "features.denseblock1.denselayer6.norm2.weight", "features.denseblock1.denselayer6.norm2.bias", "features.denseblock1.denselayer6.norm2.running_mean", "features.denseblock1.denselayer6.norm2.running_var", "features.denseblock1.denselayer6.norm2.num_batches_tracked", "features.denseblock1.denselayer6.conv2.weight", "features.transition1.norm.weight", "features.transition1.norm.bias", "features.transition1.norm.running_mean", "features.transition1.norm.running_var", "features.transition1.norm.num_batches_tracked", "features.transition1.conv.weight", "features.denseblock2.denselayer1.norm1.weight", "features.denseblock2.denselayer1.norm1.bias", "features.denseblock2.denselayer1.norm1.running_mean", "features.denseblock2.denselayer1.norm1.running_var", "features.denseblock2.denselayer1.norm1.num_batches_tracked", "features.denseblock2.denselayer1.conv1.weight", "features.denseblock2.denselayer1.norm2.weight", "features.denseblock2.denselayer1.norm2.bias", "features.denseblock2.denselayer1.norm2.running_mean", "features.denseblock2.denselayer1.norm2.running_var", "features.denseblock2.denselayer1.norm2.num_batches_tracked", "features.denseblock2.denselayer1.conv2.weight", "features.denseblock2.denselayer2.norm1.weight", "features.denseblock2.denselayer2.norm1.bias", "features.denseblock2.denselayer2.norm1.running_mean", "features.denseblock2.denselayer2.norm1.running_var", "features.denseblock2.denselayer2.norm1.num_batches_tracked", "features.denseblock2.denselayer2.conv1.weight", "features.denseblock2.denselayer2.norm2.weight", "features.denseblock2.denselayer2.norm2.bias", "features.denseblock2.denselayer2.norm2.running_mean", "features.denseblock2.denselayer2.norm2.running_var", "features.denseblock2.denselayer2.norm2.num_batches_tracked", "features.denseblock2.denselayer2.conv2.weight", "features.denseblock2.denselayer3.norm1.weight", "features.denseblock2.denselayer3.norm1.bias", "features.denseblock2.denselayer3.norm1.running_mean", "features.denseblock2.denselayer3.norm1.running_var", "features.denseblock2.denselayer3.norm1.num_batches_tracked", "features.denseblock2.denselayer3.conv1.weight", "features.denseblock2.denselayer3.norm2.weight", "features.denseblock2.denselayer3.norm2.bias", "features.denseblock2.denselayer3.norm2.running_mean", "features.denseblock2.denselayer3.norm2.running_var", "features.denseblock2.denselayer3.norm2.num_batches_tracked", "features.denseblock2.denselayer3.conv2.weight", "features.denseblock2.denselayer4.norm1.weight", "features.denseblock2.denselayer4.norm1.bias", "features.denseblock2.denselayer4.norm1.running_mean", "features.denseblock2.denselayer4.norm1.running_var", "features.denseblock2.denselayer4.norm1.num_batches_tracked", "features.denseblock2.denselayer4.conv1.weight", "features.denseblock2.denselayer4.norm2.weight", "features.denseblock2.denselayer4.norm2.bias", "features.denseblock2.denselayer4.norm2.running_mean", "features.denseblock2.denselayer4.norm2.running_var", "features.denseblock2.denselayer4.norm2.num_batches_tracked", "features.denseblock2.denselayer4.conv2.weight", "features.denseblock2.denselayer5.norm1.weight", "features.denseblock2.denselayer5.norm1.bias", "features.denseblock2.denselayer5.norm1.running_mean", "features.denseblock2.denselayer5.norm1.running_var", "features.denseblock2.denselayer5.norm1.num_batches_tracked", "features.denseblock2.denselayer5.conv1.weight", "features.denseblock2.denselayer5.norm2.weight", "features.denseblock2.denselayer5.norm2.bias", "features.denseblock2.denselayer5.norm2.running_mean", "features.denseblock2.denselayer5.norm2.running_var", "features.denseblock2.denselayer5.norm2.num_batches_tracked", "features.denseblock2.denselayer5.conv2.weight", "features.denseblock2.denselayer6.norm1.weight", "features.denseblock2.denselayer6.norm1.bias", "features.denseblock2.denselayer6.norm1.running_mean", "features.denseblock2.denselayer6.norm1.running_var", "features.denseblock2.denselayer6.norm1.num_batches_tracked", "features.denseblock2.denselayer6.conv1.weight", "features.denseblock2.denselayer6.norm2.weight", "features.denseblock2.denselayer6.norm2.bias", "features.denseblock2.denselayer6.norm2.running_mean", "features.denseblock2.denselayer6.norm2.running_var", "features.denseblock2.denselayer6.norm2.num_batches_tracked", "features.denseblock2.denselayer6.conv2.weight", "features.denseblock2.denselayer7.norm1.weight", "features.denseblock2.denselayer7.norm1.bias", "features.denseblock2.denselayer7.norm1.running_mean", "features.denseblock2.denselayer7.norm1.running_var", "features.denseblock2.denselayer7.norm1.num_batches_tracked", "features.denseblock2.denselayer7.conv1.weight", "features.denseblock2.denselayer7.norm2.weight", "features.denseblock2.denselayer7.norm2.bias", "features.denseblock2.denselayer7.norm2.running_mean", "features.denseblock2.denselayer7.norm2.running_var", "features.denseblock2.denselayer7.norm2.num_batches_tracked", "features.denseblock2.denselayer7.conv2.weight", "features.denseblock2.denselayer8.norm1.weight", "features.denseblock2.denselayer8.norm1.bias", "features.denseblock2.denselayer8.norm1.running_mean", "features.denseblock2.denselayer8.norm1.running_var", "features.denseblock2.denselayer8.norm1.num_batches_tracked", "features.denseblock2.denselayer8.conv1.weight", "features.denseblock2.denselayer8.norm2.weight", "features.denseblock2.denselayer8.norm2.bias", "features.denseblock2.denselayer8.norm2.running_mean", "features.denseblock2.denselayer8.norm2.running_var", "features.denseblock2.denselayer8.norm2.num_batches_tracked", "features.denseblock2.denselayer8.conv2.weight", "features.denseblock2.denselayer9.norm1.weight", "features.denseblock2.denselayer9.norm1.bias", "features.denseblock2.denselayer9.norm1.running_mean", "features.denseblock2.denselayer9.norm1.running_var", "features.denseblock2.denselayer9.norm1.num_batches_tracked", "features.denseblock2.denselayer9.conv1.weight", "features.denseblock2.denselayer9.norm2.weight", "features.denseblock2.denselayer9.norm2.bias", "features.denseblock2.denselayer9.norm2.running_mean", "features.denseblock2.denselayer9.norm2.running_var", "features.denseblock2.denselayer9.norm2.num_batches_tracked", "features.denseblock2.denselayer9.conv2.weight", "features.denseblock2.denselayer10.norm1.weight", "features.denseblock2.denselayer10.norm1.bias", "features.denseblock2.denselayer10.norm1.running_mean", "features.denseblock2.denselayer10.norm1.running_var", "features.denseblock2.denselayer10.norm1.num_batches_tracked", "features.denseblock2.denselayer10.conv1.weight", "features.denseblock2.denselayer10.norm2.weight", "features.denseblock2.denselayer10.norm2.bias", "features.denseblock2.denselayer10.norm2.running_mean", "features.denseblock2.denselayer10.norm2.running_var", "features.denseblock2.denselayer10.norm2.num_batches_tracked", "features.denseblock2.denselayer10.conv2.weight", "features.denseblock2.denselayer11.norm1.weight", "features.denseblock2.denselayer11.norm1.bias", "features.denseblock2.denselayer11.norm1.running_mean", "features.denseblock2.denselayer11.norm1.running_var", "features.denseblock2.denselayer11.norm1.num_batches_tracked", "features.denseblock2.denselayer11.conv1.weight", "features.denseblock2.denselayer11.norm2.weight", "features.denseblock2.denselayer11.norm2.bias", "features.denseblock2.denselayer11.norm2.running_mean", "features.denseblock2.denselayer11.norm2.running_var", "features.denseblock2.denselayer11.norm2.num_batches_tracked", "features.denseblock2.denselayer11.conv2.weight", "features.denseblock2.denselayer12.norm1.weight", "features.denseblock2.denselayer12.norm1.bias", "features.denseblock2.denselayer12.norm1.running_mean", "features.denseblock2.denselayer12.norm1.running_var", "features.denseblock2.denselayer12.norm1.num_batches_tracked", "features.denseblock2.denselayer12.conv1.weight", "features.denseblock2.denselayer12.norm2.weight", "features.denseblock2.denselayer12.norm2.bias", "features.denseblock2.denselayer12.norm2.running_mean", "features.denseblock2.denselayer12.norm2.running_var", "features.denseblock2.denselayer12.norm2.num_batches_tracked", "features.denseblock2.denselayer12.conv2.weight", "features.transition2.norm.weight", "features.transition2.norm.bias", "features.transition2.norm.running_mean", "features.transition2.norm.running_var", "features.transition2.norm.num_batches_tracked", "features.transition2.conv.weight", "features.denseblock3.denselayer1.norm1.weight", "features.denseblock3.denselayer1.norm1.bias", "features.denseblock3.denselayer1.norm1.running_mean", "features.denseblock3.denselayer1.norm1.running_var", "features.denseblock3.denselayer1.norm1.num_batches_tracked", "features.denseblock3.denselayer1.conv1.weight", "features.denseblock3.denselayer1.norm2.weight", "features.denseblock3.denselayer1.norm2.bias", "features.denseblock3.denselayer1.norm2.running_mean", "features.denseblock3.denselayer1.norm2.running_var", "features.denseblock3.denselayer1.norm2.num_batches_tracked", "features.denseblock3.denselayer1.conv2.weight", "features.denseblock3.denselayer2.norm1.weight", "features.denseblock3.denselayer2.norm1.bias", "features.denseblock3.denselayer2.norm1.running_mean", "features.denseblock3.denselayer2.norm1.running_var", "features.denseblock3.denselayer2.norm1.num_batches_tracked", "features.denseblock3.denselayer2.conv1.weight", "features.denseblock3.denselayer2.norm2.weight", "features.denseblock3.denselayer2.norm2.bias", "features.denseblock3.denselayer2.norm2.running_mean", "features.denseblock3.denselayer2.norm2.running_var", "features.denseblock3.denselayer2.norm2.num_batches_tracked", "features.denseblock3.denselayer2.conv2.weight", "features.denseblock3.denselayer3.norm1.weight", "features.denseblock3.denselayer3.norm1.bias", "features.denseblock3.denselayer3.norm1.running_mean", "features.denseblock3.denselayer3.norm1.running_var", "features.denseblock3.denselayer3.norm1.num_batches_tracked", "features.denseblock3.denselayer3.conv1.weight", "features.denseblock3.denselayer3.norm2.weight", "features.denseblock3.denselayer3.norm2.bias", "features.denseblock3.denselayer3.norm2.running_mean", "features.denseblock3.denselayer3.norm2.running_var", "features.denseblock3.denselayer3.norm2.num_batches_tracked", "features.denseblock3.denselayer3.conv2.weight", "features.denseblock3.denselayer4.norm1.weight", "features.denseblock3.denselayer4.norm1.bias", "features.denseblock3.denselayer4.norm1.running_mean", "features.denseblock3.denselayer4.norm1.running_var", "features.denseblock3.denselayer4.norm1.num_batches_tracked", "features.denseblock3.denselayer4.conv1.weight", "features.denseblock3.denselayer4.norm2.weight", "features.denseblock3.denselayer4.norm2.bias", "features.denseblock3.denselayer4.norm2.running_mean", "features.denseblock3.denselayer4.norm2.running_var", "features.denseblock3.denselayer4.norm2.num_batches_tracked", "features.denseblock3.denselayer4.conv2.weight", "features.denseblock3.denselayer5.norm1.weight", "features.denseblock3.denselayer5.norm1.bias", "features.denseblock3.denselayer5.norm1.running_mean", "features.denseblock3.denselayer5.norm1.running_var", "features.denseblock3.denselayer5.norm1.num_batches_tracked", "features.denseblock3.denselayer5.conv1.weight", "features.denseblock3.denselayer5.norm2.weight", "features.denseblock3.denselayer5.norm2.bias", "features.denseblock3.denselayer5.norm2.running_mean", "features.denseblock3.denselayer5.norm2.running_var", "features.denseblock3.denselayer5.norm2.num_batches_tracked", "features.denseblock3.denselayer5.conv2.weight", "features.denseblock3.denselayer6.norm1.weight", "features.denseblock3.denselayer6.norm1.bias", "features.denseblock3.denselayer6.norm1.running_mean", "features.denseblock3.denselayer6.norm1.running_var", "features.denseblock3.denselayer6.norm1.num_batches_tracked", "features.denseblock3.denselayer6.conv1.weight", "features.denseblock3.denselayer6.norm2.weight", "features.denseblock3.denselayer6.norm2.bias", "features.denseblock3.denselayer6.norm2.running_mean", "features.denseblock3.denselayer6.norm2.running_var", "features.denseblock3.denselayer6.norm2.num_batches_tracked", "features.denseblock3.denselayer6.conv2.weight", "features.denseblock3.denselayer7.norm1.weight", "features.denseblock3.denselayer7.norm1.bias", "features.denseblock3.denselayer7.norm1.running_mean", "features.denseblock3.denselayer7.norm1.running_var", "features.denseblock3.denselayer7.norm1.num_batches_tracked", "features.denseblock3.denselayer7.conv1.weight", "features.denseblock3.denselayer7.norm2.weight", "features.denseblock3.denselayer7.norm2.bias", "features.denseblock3.denselayer7.norm2.running_mean", "features.denseblock3.denselayer7.norm2.running_var", "features.denseblock3.denselayer7.norm2.num_batches_tracked", "features.denseblock3.denselayer7.conv2.weight", "features.denseblock3.denselayer8.norm1.weight", "features.denseblock3.denselayer8.norm1.bias", "features.denseblock3.denselayer8.norm1.running_mean", "features.denseblock3.denselayer8.norm1.running_var", "features.denseblock3.denselayer8.norm1.num_batches_tracked", "features.denseblock3.denselayer8.conv1.weight", "features.denseblock3.denselayer8.norm2.weight", "features.denseblock3.denselayer8.norm2.bias", "features.denseblock3.denselayer8.norm2.running_mean", "features.denseblock3.denselayer8.norm2.running_var", "features.denseblock3.denselayer8.norm2.num_batches_tracked", "features.denseblock3.denselayer8.conv2.weight", "features.denseblock3.denselayer9.norm1.weight", "features.denseblock3.denselayer9.norm1.bias", "features.denseblock3.denselayer9.norm1.running_mean", "features.denseblock3.denselayer9.norm1.running_var", "features.denseblock3.denselayer9.norm1.num_batches_tracked", "features.denseblock3.denselayer9.conv1.weight", "features.denseblock3.denselayer9.norm2.weight", "features.denseblock3.denselayer9.norm2.bias", "features.denseblock3.denselayer9.norm2.running_mean", "features.denseblock3.denselayer9.norm2.running_var", "features.denseblock3.denselayer9.norm2.num_batches_tracked", "features.denseblock3.denselayer9.conv2.weight", "features.denseblock3.denselayer10.norm1.weight", "features.denseblock3.denselayer10.norm1.bias", "features.denseblock3.denselayer10.norm1.running_mean", "features.denseblock3.denselayer10.norm1.running_var", "features.denseblock3.denselayer10.norm1.num_batches_tracked", "features.denseblock3.denselayer10.conv1.weight", "features.denseblock3.denselayer10.norm2.weight", "features.denseblock3.denselayer10.norm2.bias", "features.denseblock3.denselayer10.norm2.running_mean", "features.denseblock3.denselayer10.norm2.running_var", "features.denseblock3.denselayer10.norm2.num_batches_tracked", "features.denseblock3.denselayer10.conv2.weight", "features.denseblock3.denselayer11.norm1.weight", "features.denseblock3.denselayer11.norm1.bias", "features.denseblock3.denselayer11.norm1.running_mean", "features.denseblock3.denselayer11.norm1.running_var", "features.denseblock3.denselayer11.norm1.num_batches_tracked", "features.denseblock3.denselayer11.conv1.weight", "features.denseblock3.denselayer11.norm2.weight", "features.denseblock3.denselayer11.norm2.bias", "features.denseblock3.denselayer11.norm2.running_mean", "features.denseblock3.denselayer11.norm2.running_var", "features.denseblock3.denselayer11.norm2.num_batches_tracked", "features.denseblock3.denselayer11.conv2.weight", "features.denseblock3.denselayer12.norm1.weight", "features.denseblock3.denselayer12.norm1.bias", "features.denseblock3.denselayer12.norm1.running_mean", "features.denseblock3.denselayer12.norm1.running_var", "features.denseblock3.denselayer12.norm1.num_batches_tracked", "features.denseblock3.denselayer12.conv1.weight", "features.denseblock3.denselayer12.norm2.weight", "features.denseblock3.denselayer12.norm2.bias", "features.denseblock3.denselayer12.norm2.running_mean", "features.denseblock3.denselayer12.norm2.running_var", "features.denseblock3.denselayer12.norm2.num_batches_tracked", "features.denseblock3.denselayer12.conv2.weight", "features.denseblock3.denselayer13.norm1.weight", "features.denseblock3.denselayer13.norm1.bias", "features.denseblock3.denselayer13.norm1.running_mean", "features.denseblock3.denselayer13.norm1.running_var", "features.denseblock3.denselayer13.norm1.num_batches_tracked", "features.denseblock3.denselayer13.conv1.weight", "features.denseblock3.denselayer13.norm2.weight", "features.denseblock3.denselayer13.norm2.bias", "features.denseblock3.denselayer13.norm2.running_mean", "features.denseblock3.denselayer13.norm2.running_var", "features.denseblock3.denselayer13.norm2.num_batches_tracked", "features.denseblock3.denselayer13.conv2.weight", "features.denseblock3.denselayer14.norm1.weight", "features.denseblock3.denselayer14.norm1.bias", "features.denseblock3.denselayer14.norm1.running_mean", "features.denseblock3.denselayer14.norm1.running_var", "features.denseblock3.denselayer14.norm1.num_batches_tracked", "features.denseblock3.denselayer14.conv1.weight", "features.denseblock3.denselayer14.norm2.weight", "features.denseblock3.denselayer14.norm2.bias", "features.denseblock3.denselayer14.norm2.running_mean", "features.denseblock3.denselayer14.norm2.running_var", "features.denseblock3.denselayer14.norm2.num_batches_tracked", "features.denseblock3.denselayer14.conv2.weight", "features.denseblock3.denselayer15.norm1.weight", "features.denseblock3.denselayer15.norm1.bias", "features.denseblock3.denselayer15.norm1.running_mean", "features.denseblock3.denselayer15.norm1.running_var", "features.denseblock3.denselayer15.norm1.num_batches_tracked", "features.denseblock3.denselayer15.conv1.weight", "features.denseblock3.denselayer15.norm2.weight", "features.denseblock3.denselayer15.norm2.bias", "features.denseblock3.denselayer15.norm2.running_mean", "features.denseblock3.denselayer15.norm2.running_var", "features.denseblock3.denselayer15.norm2.num_batches_tracked", "features.denseblock3.denselayer15.conv2.weight", "features.denseblock3.denselayer16.norm1.weight", "features.denseblock3.denselayer16.norm1.bias", "features.denseblock3.denselayer16.norm1.running_mean", "features.denseblock3.denselayer16.norm1.running_var", "features.denseblock3.denselayer16.norm1.num_batches_tracked", "features.denseblock3.denselayer16.conv1.weight", "features.denseblock3.denselayer16.norm2.weight", "features.denseblock3.denselayer16.norm2.bias", "features.denseblock3.denselayer16.norm2.running_mean", "features.denseblock3.denselayer16.norm2.running_var", "features.denseblock3.denselayer16.norm2.num_batches_tracked", "features.denseblock3.denselayer16.conv2.weight", "features.denseblock3.denselayer17.norm1.weight", "features.denseblock3.denselayer17.norm1.bias", "features.denseblock3.denselayer17.norm1.running_mean", "features.denseblock3.denselayer17.norm1.running_var", "features.denseblock3.denselayer17.norm1.num_batches_tracked", "features.denseblock3.denselayer17.conv1.weight", "features.denseblock3.denselayer17.norm2.weight", "features.denseblock3.denselayer17.norm2.bias", "features.denseblock3.denselayer17.norm2.running_mean", "features.denseblock3.denselayer17.norm2.running_var", "features.denseblock3.denselayer17.norm2.num_batches_tracked", "features.denseblock3.denselayer17.conv2.weight", "features.denseblock3.denselayer18.norm1.weight", "features.denseblock3.denselayer18.norm1.bias", "features.denseblock3.denselayer18.norm1.running_mean", "features.denseblock3.denselayer18.norm1.running_var", "features.denseblock3.denselayer18.norm1.num_batches_tracked", "features.denseblock3.denselayer18.conv1.weight", "features.denseblock3.denselayer18.norm2.weight", "features.denseblock3.denselayer18.norm2.bias", "features.denseblock3.denselayer18.norm2.running_mean", "features.denseblock3.denselayer18.norm2.running_var", "features.denseblock3.denselayer18.norm2.num_batches_tracked", "features.denseblock3.denselayer18.conv2.weight", "features.denseblock3.denselayer19.norm1.weight", "features.denseblock3.denselayer19.norm1.bias", "features.denseblock3.denselayer19.norm1.running_mean", "features.denseblock3.denselayer19.norm1.running_var", "features.denseblock3.denselayer19.norm1.num_batches_tracked", "features.denseblock3.denselayer19.conv1.weight", "features.denseblock3.denselayer19.norm2.weight", "features.denseblock3.denselayer19.norm2.bias", "features.denseblock3.denselayer19.norm2.running_mean", "features.denseblock3.denselayer19.norm2.running_var", "features.denseblock3.denselayer19.norm2.num_batches_tracked", "features.denseblock3.denselayer19.conv2.weight", "features.denseblock3.denselayer20.norm1.weight", "features.denseblock3.denselayer20.norm1.bias", "features.denseblock3.denselayer20.norm1.running_mean", "features.denseblock3.denselayer20.norm1.running_var", "features.denseblock3.denselayer20.norm1.num_batches_tracked", "features.denseblock3.denselayer20.conv1.weight", "features.denseblock3.denselayer20.norm2.weight", "features.denseblock3.denselayer20.norm2.bias", "features.denseblock3.denselayer20.norm2.running_mean", "features.denseblock3.denselayer20.norm2.running_var", "features.denseblock3.denselayer20.norm2.num_batches_tracked", "features.denseblock3.denselayer20.conv2.weight", "features.denseblock3.denselayer21.norm1.weight", "features.denseblock3.denselayer21.norm1.bias", "features.denseblock3.denselayer21.norm1.running_mean", "features.denseblock3.denselayer21.norm1.running_var", "features.denseblock3.denselayer21.norm1.num_batches_tracked", "features.denseblock3.denselayer21.conv1.weight", "features.denseblock3.denselayer21.norm2.weight", "features.denseblock3.denselayer21.norm2.bias", "features.denseblock3.denselayer21.norm2.running_mean", "features.denseblock3.denselayer21.norm2.running_var", "features.denseblock3.denselayer21.norm2.num_batches_tracked", "features.denseblock3.denselayer21.conv2.weight", "features.denseblock3.denselayer22.norm1.weight", "features.denseblock3.denselayer22.norm1.bias", "features.denseblock3.denselayer22.norm1.running_mean", "features.denseblock3.denselayer22.norm1.running_var", "features.denseblock3.denselayer22.norm1.num_batches_tracked", "features.denseblock3.denselayer22.conv1.weight", "features.denseblock3.denselayer22.norm2.weight", "features.denseblock3.denselayer22.norm2.bias", "features.denseblock3.denselayer22.norm2.running_mean", "features.denseblock3.denselayer22.norm2.running_var", "features.denseblock3.denselayer22.norm2.num_batches_tracked", "features.denseblock3.denselayer22.conv2.weight", "features.denseblock3.denselayer23.norm1.weight", "features.denseblock3.denselayer23.norm1.bias", "features.denseblock3.denselayer23.norm1.running_mean", "features.denseblock3.denselayer23.norm1.running_var", "features.denseblock3.denselayer23.norm1.num_batches_tracked", "features.denseblock3.denselayer23.conv1.weight", "features.denseblock3.denselayer23.norm2.weight", "features.denseblock3.denselayer23.norm2.bias", "features.denseblock3.denselayer23.norm2.running_mean", "features.denseblock3.denselayer23.norm2.running_var", "features.denseblock3.denselayer23.norm2.num_batches_tracked", "features.denseblock3.denselayer23.conv2.weight", "features.denseblock3.denselayer24.norm1.weight", "features.denseblock3.denselayer24.norm1.bias", "features.denseblock3.denselayer24.norm1.running_mean", "features.denseblock3.denselayer24.norm1.running_var", "features.denseblock3.denselayer24.norm1.num_batches_tracked", "features.denseblock3.denselayer24.conv1.weight", "features.denseblock3.denselayer24.norm2.weight", "features.denseblock3.denselayer24.norm2.bias", "features.denseblock3.denselayer24.norm2.running_mean", "features.denseblock3.denselayer24.norm2.running_var", "features.denseblock3.denselayer24.norm2.num_batches_tracked", "features.denseblock3.denselayer24.conv2.weight", "features.transition3.norm.weight", "features.transition3.norm.bias", "features.transition3.norm.running_mean", "features.transition3.norm.running_var", "features.transition3.norm.num_batches_tracked", "features.transition3.conv.weight", "features.denseblock4.denselayer1.norm1.weight", "features.denseblock4.denselayer1.norm1.bias", "features.denseblock4.denselayer1.norm1.running_mean", "features.denseblock4.denselayer1.norm1.running_var", "features.denseblock4.denselayer1.norm1.num_batches_tracked", "features.denseblock4.denselayer1.conv1.weight", "features.denseblock4.denselayer1.norm2.weight", "features.denseblock4.denselayer1.norm2.bias", "features.denseblock4.denselayer1.norm2.running_mean", "features.denseblock4.denselayer1.norm2.running_var", "features.denseblock4.denselayer1.norm2.num_batches_tracked", "features.denseblock4.denselayer1.conv2.weight", "features.denseblock4.denselayer2.norm1.weight", "features.denseblock4.denselayer2.norm1.bias", "features.denseblock4.denselayer2.norm1.running_mean", "features.denseblock4.denselayer2.norm1.running_var", "features.denseblock4.denselayer2.norm1.num_batches_tracked", "features.denseblock4.denselayer2.conv1.weight", "features.denseblock4.denselayer2.norm2.weight", "features.denseblock4.denselayer2.norm2.bias", "features.denseblock4.denselayer2.norm2.running_mean", "features.denseblock4.denselayer2.norm2.running_var", "features.denseblock4.denselayer2.norm2.num_batches_tracked", "features.denseblock4.denselayer2.conv2.weight", "features.denseblock4.denselayer3.norm1.weight", "features.denseblock4.denselayer3.norm1.bias", "features.denseblock4.denselayer3.norm1.running_mean", "features.denseblock4.denselayer3.norm1.running_var", "features.denseblock4.denselayer3.norm1.num_batches_tracked", "features.denseblock4.denselayer3.conv1.weight", "features.denseblock4.denselayer3.norm2.weight", "features.denseblock4.denselayer3.norm2.bias", "features.denseblock4.denselayer3.norm2.running_mean", "features.denseblock4.denselayer3.norm2.running_var", "features.denseblock4.denselayer3.norm2.num_batches_tracked", "features.denseblock4.denselayer3.conv2.weight", "features.denseblock4.denselayer4.norm1.weight", "features.denseblock4.denselayer4.norm1.bias", "features.denseblock4.denselayer4.norm1.running_mean", "features.denseblock4.denselayer4.norm1.running_var", "features.denseblock4.denselayer4.norm1.num_batches_tracked", "features.denseblock4.denselayer4.conv1.weight", "features.denseblock4.denselayer4.norm2.weight", "features.denseblock4.denselayer4.norm2.bias", "features.denseblock4.denselayer4.norm2.running_mean", "features.denseblock4.denselayer4.norm2.running_var", "features.denseblock4.denselayer4.norm2.num_batches_tracked", "features.denseblock4.denselayer4.conv2.weight", "features.denseblock4.denselayer5.norm1.weight", "features.denseblock4.denselayer5.norm1.bias", "features.denseblock4.denselayer5.norm1.running_mean", "features.denseblock4.denselayer5.norm1.running_var", "features.denseblock4.denselayer5.norm1.num_batches_tracked", "features.denseblock4.denselayer5.conv1.weight", "features.denseblock4.denselayer5.norm2.weight", "features.denseblock4.denselayer5.norm2.bias", "features.denseblock4.denselayer5.norm2.running_mean", "features.denseblock4.denselayer5.norm2.running_var", "features.denseblock4.denselayer5.norm2.num_batches_tracked", "features.denseblock4.denselayer5.conv2.weight", "features.denseblock4.denselayer6.norm1.weight", "features.denseblock4.denselayer6.norm1.bias", "features.denseblock4.denselayer6.norm1.running_mean", "features.denseblock4.denselayer6.norm1.running_var", "features.denseblock4.denselayer6.norm1.num_batches_tracked", "features.denseblock4.denselayer6.conv1.weight", "features.denseblock4.denselayer6.norm2.weight", "features.denseblock4.denselayer6.norm2.bias", "features.denseblock4.denselayer6.norm2.running_mean", "features.denseblock4.denselayer6.norm2.running_var", "features.denseblock4.denselayer6.norm2.num_batches_tracked", "features.denseblock4.denselayer6.conv2.weight", "features.denseblock4.denselayer7.norm1.weight", "features.denseblock4.denselayer7.norm1.bias", "features.denseblock4.denselayer7.norm1.running_mean", "features.denseblock4.denselayer7.norm1.running_var", "features.denseblock4.denselayer7.norm1.num_batches_tracked", "features.denseblock4.denselayer7.conv1.weight", "features.denseblock4.denselayer7.norm2.weight", "features.denseblock4.denselayer7.norm2.bias", "features.denseblock4.denselayer7.norm2.running_mean", "features.denseblock4.denselayer7.norm2.running_var", "features.denseblock4.denselayer7.norm2.num_batches_tracked", "features.denseblock4.denselayer7.conv2.weight", "features.denseblock4.denselayer8.norm1.weight", "features.denseblock4.denselayer8.norm1.bias", "features.denseblock4.denselayer8.norm1.running_mean", "features.denseblock4.denselayer8.norm1.running_var", "features.denseblock4.denselayer8.norm1.num_batches_tracked", "features.denseblock4.denselayer8.conv1.weight", "features.denseblock4.denselayer8.norm2.weight", "features.denseblock4.denselayer8.norm2.bias", "features.denseblock4.denselayer8.norm2.running_mean", "features.denseblock4.denselayer8.norm2.running_var", "features.denseblock4.denselayer8.norm2.num_batches_tracked", "features.denseblock4.denselayer8.conv2.weight", "features.denseblock4.denselayer9.norm1.weight", "features.denseblock4.denselayer9.norm1.bias", "features.denseblock4.denselayer9.norm1.running_mean", "features.denseblock4.denselayer9.norm1.running_var", "features.denseblock4.denselayer9.norm1.num_batches_tracked", "features.denseblock4.denselayer9.conv1.weight", "features.denseblock4.denselayer9.norm2.weight", "features.denseblock4.denselayer9.norm2.bias", "features.denseblock4.denselayer9.norm2.running_mean", "features.denseblock4.denselayer9.norm2.running_var", "features.denseblock4.denselayer9.norm2.num_batches_tracked", "features.denseblock4.denselayer9.conv2.weight", "features.denseblock4.denselayer10.norm1.weight", "features.denseblock4.denselayer10.norm1.bias", "features.denseblock4.denselayer10.norm1.running_mean", "features.denseblock4.denselayer10.norm1.running_var", "features.denseblock4.denselayer10.norm1.num_batches_tracked", "features.denseblock4.denselayer10.conv1.weight", "features.denseblock4.denselayer10.norm2.weight", "features.denseblock4.denselayer10.norm2.bias", "features.denseblock4.denselayer10.norm2.running_mean", "features.denseblock4.denselayer10.norm2.running_var", "features.denseblock4.denselayer10.norm2.num_batches_tracked", "features.denseblock4.denselayer10.conv2.weight", "features.denseblock4.denselayer11.norm1.weight", "features.denseblock4.denselayer11.norm1.bias", "features.denseblock4.denselayer11.norm1.running_mean", "features.denseblock4.denselayer11.norm1.running_var", "features.denseblock4.denselayer11.norm1.num_batches_tracked", "features.denseblock4.denselayer11.conv1.weight", "features.denseblock4.denselayer11.norm2.weight", "features.denseblock4.denselayer11.norm2.bias", "features.denseblock4.denselayer11.norm2.running_mean", "features.denseblock4.denselayer11.norm2.running_var", "features.denseblock4.denselayer11.norm2.num_batches_tracked", "features.denseblock4.denselayer11.conv2.weight", "features.denseblock4.denselayer12.norm1.weight", "features.denseblock4.denselayer12.norm1.bias", "features.denseblock4.denselayer12.norm1.running_mean", "features.denseblock4.denselayer12.norm1.running_var", "features.denseblock4.denselayer12.norm1.num_batches_tracked", "features.denseblock4.denselayer12.conv1.weight", "features.denseblock4.denselayer12.norm2.weight", "features.denseblock4.denselayer12.norm2.bias", "features.denseblock4.denselayer12.norm2.running_mean", "features.denseblock4.denselayer12.norm2.running_var", "features.denseblock4.denselayer12.norm2.num_batches_tracked", "features.denseblock4.denselayer12.conv2.weight", "features.denseblock4.denselayer13.norm1.weight", "features.denseblock4.denselayer13.norm1.bias", "features.denseblock4.denselayer13.norm1.running_mean", "features.denseblock4.denselayer13.norm1.running_var", "features.denseblock4.denselayer13.norm1.num_batches_tracked", "features.denseblock4.denselayer13.conv1.weight", "features.denseblock4.denselayer13.norm2.weight", "features.denseblock4.denselayer13.norm2.bias", "features.denseblock4.denselayer13.norm2.running_mean", "features.denseblock4.denselayer13.norm2.running_var", "features.denseblock4.denselayer13.norm2.num_batches_tracked", "features.denseblock4.denselayer13.conv2.weight", "features.denseblock4.denselayer14.norm1.weight", "features.denseblock4.denselayer14.norm1.bias", "features.denseblock4.denselayer14.norm1.running_mean", "features.denseblock4.denselayer14.norm1.running_var", "features.denseblock4.denselayer14.norm1.num_batches_tracked", "features.denseblock4.denselayer14.conv1.weight", "features.denseblock4.denselayer14.norm2.weight", "features.denseblock4.denselayer14.norm2.bias", "features.denseblock4.denselayer14.norm2.running_mean", "features.denseblock4.denselayer14.norm2.running_var", "features.denseblock4.denselayer14.norm2.num_batches_tracked", "features.denseblock4.denselayer14.conv2.weight", "features.denseblock4.denselayer15.norm1.weight", "features.denseblock4.denselayer15.norm1.bias", "features.denseblock4.denselayer15.norm1.running_mean", "features.denseblock4.denselayer15.norm1.running_var", "features.denseblock4.denselayer15.norm1.num_batches_tracked", "features.denseblock4.denselayer15.conv1.weight", "features.denseblock4.denselayer15.norm2.weight", "features.denseblock4.denselayer15.norm2.bias", "features.denseblock4.denselayer15.norm2.running_mean", "features.denseblock4.denselayer15.norm2.running_var", "features.denseblock4.denselayer15.norm2.num_batches_tracked", "features.denseblock4.denselayer15.conv2.weight", "features.denseblock4.denselayer16.norm1.weight", "features.denseblock4.denselayer16.norm1.bias", "features.denseblock4.denselayer16.norm1.running_mean", "features.denseblock4.denselayer16.norm1.running_var", "features.denseblock4.denselayer16.norm1.num_batches_tracked", "features.denseblock4.denselayer16.conv1.weight", "features.denseblock4.denselayer16.norm2.weight", "features.denseblock4.denselayer16.norm2.bias", "features.denseblock4.denselayer16.norm2.running_mean", "features.denseblock4.denselayer16.norm2.running_var", "features.denseblock4.denselayer16.norm2.num_batches_tracked", "features.denseblock4.denselayer16.conv2.weight", "features.norm5.weight", "features.norm5.bias", "features.norm5.running_mean", "features.norm5.running_var", "features.norm5.num_batches_tracked", "classifier.weight", "classifier.bias". 

In [None]:
mean_embeds = get_class_mean_embeddings(resnetxt50_pretrained, train_loader)
print(mean_embeds)

In [None]:
def compute_distances(model, mean_embeds):
    print("l2 dist:")
    dist_iid = calc_dist_mean(model, id_val_loader, mean_embeds)
    print(calc_mean_std(dist_iid[0]))
    print(calc_mean_std(dist_iid[1]))
    dist_ood = calc_dist_mean(model, val_loader, mean_embeds)
    print(calc_mean_std(dist_ood[0]))
    print(calc_mean_std(dist_ood[1]))

    print("cos sim:")
    dist_iid = calc_dist_mean(model, id_val_loader, mean_embeds, "cos")
    print(calc_mean_std(dist_iid[0]))
    print(calc_mean_std(dist_iid[1]))
    dist_ood = calc_dist_mean(model, val_loader, mean_embeds, "cos")
    print(calc_mean_std(dist_ood[0]))
    print(calc_mean_std(dist_ood[1]))

    print("mink dist:")
    dist_iid = calc_dist_mean(model, id_val_loader, mean_embeds, "mink")
    print(calc_mean_std(dist_iid[0]))
    print(calc_mean_std(dist_iid[1]))
    dist_ood = calc_dist_mean(model, val_loader, mean_embeds, "mink")
    print(calc_mean_std(dist_ood[0]))
    print(calc_mean_std(dist_ood[1]))

In [None]:
compute_distances(resnetxt50_pretrained, mean_embeds)

In [104]:
"""
densenet121
"""
densenet121_pretrained.load_state_dict(torch.load("densenet121_pretrained_all_grad.pt"))

<All keys matched successfully>

In [105]:
mean_embeds = get_class_mean_embeddings(densenet121_pretrained, train_loader)
print(mean_embeds)

Sequential(
  (conv0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (norm0): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu0): ReLU(inplace=True)
  (pool0): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (denseblock1): _DenseBlock(
    (denselayer1): _DenseLayer(
      (norm1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu1): ReLU(inplace=True)
      (conv1): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu2): ReLU(inplace=True)
      (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    )
    (denselayer2): _DenseLayer(
      (norm1): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu1): ReLU(inplace=True)
      (conv1): Conv2d(96, 128, ke

  0%|          | 0/3119 [00:00<?, ?it/s]


RuntimeError: Given groups=1, weight of size [64, 3, 7, 7], expected input[32, 1024, 3, 3] to have 3 channels, but got 1024 channels instead