In [1]:
# Custom dataset class
import os
import csv
from torch.utils.data import Dataset, DataLoader
import torch
import torch.nn as nn
from torchvision import models
from torchvision import transforms
from sklearn.metrics import confusion_matrix
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from PIL import Image
from sklearn.metrics import confusion_matrix, classification_report
from sklearn.metrics import roc_curve, roc_auc_score
import torch.nn.functional as F

# custom dataset on csv files
class CustomDataset(Dataset):
    def __init__(self, csv_file, transform=None):
        self.data = []
        self.labels = []  # Store labels separately
        self.patient_ids = []  # Store patient IDs separately
        self.transform = transform
        
        # read csv
        with open(csv_file, 'r') as csvfile:
            reader = csv.DictReader(csvfile)
            for row in reader:
                image_path = row['ImagePath']
                label = row['Label']
                patient_id = row['PatientID']  # Assuming 'PatientID' is the column name in your CSV file

                if label == 'NFF':
                    label = 0
                elif label == 'AFF':
                    label = 1
                else:
                    raise ValueError("Invalid label in CSV file.")
                self.data.append((image_path, label))
                self.labels.append(label)
                self.patient_ids.append(patient_id)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        image_path, label = self.data[idx]
        image = Image.open(image_path).convert('RGB')

        if self.transform:
            image = self.transform(image)

        return image, label

# plot roc curve
def plot_roc_curve(all_labels, all_predictions):
    fpr, tpr, thresholds = roc_curve(all_labels, all_predictions)
    roc_auc = roc_auc_score(all_labels, all_predictions)

    plt.figure(figsize=(8, 6))
    plt.plot(fpr, tpr, color='darkorange', lw=2, label=f'ROC curve (area = {roc_auc:.2f})')
    plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title('Receiver Operating Characteristic (ROC) Curve')
    plt.legend(loc='lower right')
    plt.show()

# test model performance 
def test_model(model, test_dataset, batch_size):
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, pin_memory=True)
    model.eval()
    
    correct_test = 0
    total_test = 0
    test_loss = 0
    
    criterion = nn.CrossEntropyLoss()
    
    with torch.no_grad():
        all_predictions = []
        all_labels = []
        for images, labels in test_loader:
            images = images.to(device)
            labels = labels.to(device)

            outputs = model(images)
            loss = criterion(outputs, labels)

            _, predicted = torch.max(outputs.data, 1)
            all_predictions.extend(predicted.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
            total_test += labels.size(0)
            correct_test += (predicted == labels).sum().item()
            test_loss += loss.item()

    test_accuracy = 100 * correct_test / total_test
    test_loss /= len(test_loader)

    print(f"Test Loss: {test_loss:.4f}, Test Acc: {test_accuracy:.2f}%")
     
    auc_score = roc_auc_score(all_labels, all_predictions)
    conf_matrix = confusion_matrix(all_labels, all_predictions)
    
    # Precision、Recall、F1 Score
    class_labels = {0: 'NFF', 1: 'AFF'}

    classification_rep = classification_report(all_labels, all_predictions, target_names=[class_labels[i] for i in range(len(class_labels))])
    print("AUC:", auc_score)
    print("Confusion Matrix:")
    print(conf_matrix)
    print("Classification Report:")
    print(classification_rep)

    plot_roc_curve(all_labels, all_predictions)
    
    

In [2]:
# my net work, the same as jobs
class res101(nn.Module):
    def __init__(self, num_classes=2):
        super(res101, self).__init__()
        self.inplanes = 64  # 初始的inplanes值
        # 定义ResNet-101的前几层，不包括最后的全连接层
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(64, 3)
        self.layer2 = self._make_layer(128, 4, stride=2)
        self.layer3 = self._make_layer(256, 23, stride=2)
        self.layer4 = self._make_layer(512, 3, stride=2)
        # 平均池化层
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        # 分类层
        self.fc = nn.Linear(512 * 4, num_classes)

    def _make_layer(self, planes, blocks, stride=1):
        layers = []
        layers.append(Bottleneck(self.inplanes, planes, stride))
        self.inplanes = planes * 4  # Bottleneck的输出通道数是planes * 4
        for _ in range(1, blocks):
            layers.append(Bottleneck(self.inplanes, planes))
        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = torch.flatten(x, 1)  # 将特征图展平
        x = self.fc(x)
        return x

class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * 4)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = None
        if stride != 1 or inplanes != planes * Bottleneck.expansion:
            self.downsample = nn.Sequential(
                nn.Conv2d(inplanes, planes * Bottleneck.expansion, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * Bottleneck.expansion),
            )

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out
    

In [4]:
# create net
model = res101(num_classes=2)  # num_classes 2

# checkpoint
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
checkpoint_path = '/local/data1/honzh073/data/admin@liu.se/transfer/1c958eef-c89e-48b5-9c20-8621eb09a56c/workspace/app_server/FL_global_model.pt'
checkpoint = torch.load(checkpoint_path, map_location=device)
print(checkpoint.keys())

# modify checkpoint keys，match model
# load model weights
state_dict = checkpoint['model']
modified_state_dict = {}
for key, value in state_dict.items():
    # modify key names，match model layers
    new_key = key.replace("model.", "")
    modified_state_dict[new_key] = value

# load weights to model
model.load_state_dict(modified_state_dict)
model = model.to(device)
model.eval()


odict_keys(['model', 'train_conf'])


RuntimeError: Error(s) in loading state_dict for res101:
	Missing key(s) in state_dict: "conv1.weight", "bn1.weight", "bn1.bias", "bn1.running_mean", "bn1.running_var", "layer1.0.conv1.weight", "layer1.0.bn1.weight", "layer1.0.bn1.bias", "layer1.0.bn1.running_mean", "layer1.0.bn1.running_var", "layer1.0.conv2.weight", "layer1.0.bn2.weight", "layer1.0.bn2.bias", "layer1.0.bn2.running_mean", "layer1.0.bn2.running_var", "layer1.0.conv3.weight", "layer1.0.bn3.weight", "layer1.0.bn3.bias", "layer1.0.bn3.running_mean", "layer1.0.bn3.running_var", "layer1.0.downsample.0.weight", "layer1.0.downsample.1.weight", "layer1.0.downsample.1.bias", "layer1.0.downsample.1.running_mean", "layer1.0.downsample.1.running_var", "layer1.1.conv1.weight", "layer1.1.bn1.weight", "layer1.1.bn1.bias", "layer1.1.bn1.running_mean", "layer1.1.bn1.running_var", "layer1.1.conv2.weight", "layer1.1.bn2.weight", "layer1.1.bn2.bias", "layer1.1.bn2.running_mean", "layer1.1.bn2.running_var", "layer1.1.conv3.weight", "layer1.1.bn3.weight", "layer1.1.bn3.bias", "layer1.1.bn3.running_mean", "layer1.1.bn3.running_var", "layer1.2.conv1.weight", "layer1.2.bn1.weight", "layer1.2.bn1.bias", "layer1.2.bn1.running_mean", "layer1.2.bn1.running_var", "layer1.2.conv2.weight", "layer1.2.bn2.weight", "layer1.2.bn2.bias", "layer1.2.bn2.running_mean", "layer1.2.bn2.running_var", "layer1.2.conv3.weight", "layer1.2.bn3.weight", "layer1.2.bn3.bias", "layer1.2.bn3.running_mean", "layer1.2.bn3.running_var", "layer2.0.conv1.weight", "layer2.0.bn1.weight", "layer2.0.bn1.bias", "layer2.0.bn1.running_mean", "layer2.0.bn1.running_var", "layer2.0.conv2.weight", "layer2.0.bn2.weight", "layer2.0.bn2.bias", "layer2.0.bn2.running_mean", "layer2.0.bn2.running_var", "layer2.0.conv3.weight", "layer2.0.bn3.weight", "layer2.0.bn3.bias", "layer2.0.bn3.running_mean", "layer2.0.bn3.running_var", "layer2.0.downsample.0.weight", "layer2.0.downsample.1.weight", "layer2.0.downsample.1.bias", "layer2.0.downsample.1.running_mean", "layer2.0.downsample.1.running_var", "layer2.1.conv1.weight", "layer2.1.bn1.weight", "layer2.1.bn1.bias", "layer2.1.bn1.running_mean", "layer2.1.bn1.running_var", "layer2.1.conv2.weight", "layer2.1.bn2.weight", "layer2.1.bn2.bias", "layer2.1.bn2.running_mean", "layer2.1.bn2.running_var", "layer2.1.conv3.weight", "layer2.1.bn3.weight", "layer2.1.bn3.bias", "layer2.1.bn3.running_mean", "layer2.1.bn3.running_var", "layer2.2.conv1.weight", "layer2.2.bn1.weight", "layer2.2.bn1.bias", "layer2.2.bn1.running_mean", "layer2.2.bn1.running_var", "layer2.2.conv2.weight", "layer2.2.bn2.weight", "layer2.2.bn2.bias", "layer2.2.bn2.running_mean", "layer2.2.bn2.running_var", "layer2.2.conv3.weight", "layer2.2.bn3.weight", "layer2.2.bn3.bias", "layer2.2.bn3.running_mean", "layer2.2.bn3.running_var", "layer2.3.conv1.weight", "layer2.3.bn1.weight", "layer2.3.bn1.bias", "layer2.3.bn1.running_mean", "layer2.3.bn1.running_var", "layer2.3.conv2.weight", "layer2.3.bn2.weight", "layer2.3.bn2.bias", "layer2.3.bn2.running_mean", "layer2.3.bn2.running_var", "layer2.3.conv3.weight", "layer2.3.bn3.weight", "layer2.3.bn3.bias", "layer2.3.bn3.running_mean", "layer2.3.bn3.running_var", "layer3.0.conv1.weight", "layer3.0.bn1.weight", "layer3.0.bn1.bias", "layer3.0.bn1.running_mean", "layer3.0.bn1.running_var", "layer3.0.conv2.weight", "layer3.0.bn2.weight", "layer3.0.bn2.bias", "layer3.0.bn2.running_mean", "layer3.0.bn2.running_var", "layer3.0.conv3.weight", "layer3.0.bn3.weight", "layer3.0.bn3.bias", "layer3.0.bn3.running_mean", "layer3.0.bn3.running_var", "layer3.0.downsample.0.weight", "layer3.0.downsample.1.weight", "layer3.0.downsample.1.bias", "layer3.0.downsample.1.running_mean", "layer3.0.downsample.1.running_var", "layer3.1.conv1.weight", "layer3.1.bn1.weight", "layer3.1.bn1.bias", "layer3.1.bn1.running_mean", "layer3.1.bn1.running_var", "layer3.1.conv2.weight", "layer3.1.bn2.weight", "layer3.1.bn2.bias", "layer3.1.bn2.running_mean", "layer3.1.bn2.running_var", "layer3.1.conv3.weight", "layer3.1.bn3.weight", "layer3.1.bn3.bias", "layer3.1.bn3.running_mean", "layer3.1.bn3.running_var", "layer3.2.conv1.weight", "layer3.2.bn1.weight", "layer3.2.bn1.bias", "layer3.2.bn1.running_mean", "layer3.2.bn1.running_var", "layer3.2.conv2.weight", "layer3.2.bn2.weight", "layer3.2.bn2.bias", "layer3.2.bn2.running_mean", "layer3.2.bn2.running_var", "layer3.2.conv3.weight", "layer3.2.bn3.weight", "layer3.2.bn3.bias", "layer3.2.bn3.running_mean", "layer3.2.bn3.running_var", "layer3.3.conv1.weight", "layer3.3.bn1.weight", "layer3.3.bn1.bias", "layer3.3.bn1.running_mean", "layer3.3.bn1.running_var", "layer3.3.conv2.weight", "layer3.3.bn2.weight", "layer3.3.bn2.bias", "layer3.3.bn2.running_mean", "layer3.3.bn2.running_var", "layer3.3.conv3.weight", "layer3.3.bn3.weight", "layer3.3.bn3.bias", "layer3.3.bn3.running_mean", "layer3.3.bn3.running_var", "layer3.4.conv1.weight", "layer3.4.bn1.weight", "layer3.4.bn1.bias", "layer3.4.bn1.running_mean", "layer3.4.bn1.running_var", "layer3.4.conv2.weight", "layer3.4.bn2.weight", "layer3.4.bn2.bias", "layer3.4.bn2.running_mean", "layer3.4.bn2.running_var", "layer3.4.conv3.weight", "layer3.4.bn3.weight", "layer3.4.bn3.bias", "layer3.4.bn3.running_mean", "layer3.4.bn3.running_var", "layer3.5.conv1.weight", "layer3.5.bn1.weight", "layer3.5.bn1.bias", "layer3.5.bn1.running_mean", "layer3.5.bn1.running_var", "layer3.5.conv2.weight", "layer3.5.bn2.weight", "layer3.5.bn2.bias", "layer3.5.bn2.running_mean", "layer3.5.bn2.running_var", "layer3.5.conv3.weight", "layer3.5.bn3.weight", "layer3.5.bn3.bias", "layer3.5.bn3.running_mean", "layer3.5.bn3.running_var", "layer3.6.conv1.weight", "layer3.6.bn1.weight", "layer3.6.bn1.bias", "layer3.6.bn1.running_mean", "layer3.6.bn1.running_var", "layer3.6.conv2.weight", "layer3.6.bn2.weight", "layer3.6.bn2.bias", "layer3.6.bn2.running_mean", "layer3.6.bn2.running_var", "layer3.6.conv3.weight", "layer3.6.bn3.weight", "layer3.6.bn3.bias", "layer3.6.bn3.running_mean", "layer3.6.bn3.running_var", "layer3.7.conv1.weight", "layer3.7.bn1.weight", "layer3.7.bn1.bias", "layer3.7.bn1.running_mean", "layer3.7.bn1.running_var", "layer3.7.conv2.weight", "layer3.7.bn2.weight", "layer3.7.bn2.bias", "layer3.7.bn2.running_mean", "layer3.7.bn2.running_var", "layer3.7.conv3.weight", "layer3.7.bn3.weight", "layer3.7.bn3.bias", "layer3.7.bn3.running_mean", "layer3.7.bn3.running_var", "layer3.8.conv1.weight", "layer3.8.bn1.weight", "layer3.8.bn1.bias", "layer3.8.bn1.running_mean", "layer3.8.bn1.running_var", "layer3.8.conv2.weight", "layer3.8.bn2.weight", "layer3.8.bn2.bias", "layer3.8.bn2.running_mean", "layer3.8.bn2.running_var", "layer3.8.conv3.weight", "layer3.8.bn3.weight", "layer3.8.bn3.bias", "layer3.8.bn3.running_mean", "layer3.8.bn3.running_var", "layer3.9.conv1.weight", "layer3.9.bn1.weight", "layer3.9.bn1.bias", "layer3.9.bn1.running_mean", "layer3.9.bn1.running_var", "layer3.9.conv2.weight", "layer3.9.bn2.weight", "layer3.9.bn2.bias", "layer3.9.bn2.running_mean", "layer3.9.bn2.running_var", "layer3.9.conv3.weight", "layer3.9.bn3.weight", "layer3.9.bn3.bias", "layer3.9.bn3.running_mean", "layer3.9.bn3.running_var", "layer3.10.conv1.weight", "layer3.10.bn1.weight", "layer3.10.bn1.bias", "layer3.10.bn1.running_mean", "layer3.10.bn1.running_var", "layer3.10.conv2.weight", "layer3.10.bn2.weight", "layer3.10.bn2.bias", "layer3.10.bn2.running_mean", "layer3.10.bn2.running_var", "layer3.10.conv3.weight", "layer3.10.bn3.weight", "layer3.10.bn3.bias", "layer3.10.bn3.running_mean", "layer3.10.bn3.running_var", "layer3.11.conv1.weight", "layer3.11.bn1.weight", "layer3.11.bn1.bias", "layer3.11.bn1.running_mean", "layer3.11.bn1.running_var", "layer3.11.conv2.weight", "layer3.11.bn2.weight", "layer3.11.bn2.bias", "layer3.11.bn2.running_mean", "layer3.11.bn2.running_var", "layer3.11.conv3.weight", "layer3.11.bn3.weight", "layer3.11.bn3.bias", "layer3.11.bn3.running_mean", "layer3.11.bn3.running_var", "layer3.12.conv1.weight", "layer3.12.bn1.weight", "layer3.12.bn1.bias", "layer3.12.bn1.running_mean", "layer3.12.bn1.running_var", "layer3.12.conv2.weight", "layer3.12.bn2.weight", "layer3.12.bn2.bias", "layer3.12.bn2.running_mean", "layer3.12.bn2.running_var", "layer3.12.conv3.weight", "layer3.12.bn3.weight", "layer3.12.bn3.bias", "layer3.12.bn3.running_mean", "layer3.12.bn3.running_var", "layer3.13.conv1.weight", "layer3.13.bn1.weight", "layer3.13.bn1.bias", "layer3.13.bn1.running_mean", "layer3.13.bn1.running_var", "layer3.13.conv2.weight", "layer3.13.bn2.weight", "layer3.13.bn2.bias", "layer3.13.bn2.running_mean", "layer3.13.bn2.running_var", "layer3.13.conv3.weight", "layer3.13.bn3.weight", "layer3.13.bn3.bias", "layer3.13.bn3.running_mean", "layer3.13.bn3.running_var", "layer3.14.conv1.weight", "layer3.14.bn1.weight", "layer3.14.bn1.bias", "layer3.14.bn1.running_mean", "layer3.14.bn1.running_var", "layer3.14.conv2.weight", "layer3.14.bn2.weight", "layer3.14.bn2.bias", "layer3.14.bn2.running_mean", "layer3.14.bn2.running_var", "layer3.14.conv3.weight", "layer3.14.bn3.weight", "layer3.14.bn3.bias", "layer3.14.bn3.running_mean", "layer3.14.bn3.running_var", "layer3.15.conv1.weight", "layer3.15.bn1.weight", "layer3.15.bn1.bias", "layer3.15.bn1.running_mean", "layer3.15.bn1.running_var", "layer3.15.conv2.weight", "layer3.15.bn2.weight", "layer3.15.bn2.bias", "layer3.15.bn2.running_mean", "layer3.15.bn2.running_var", "layer3.15.conv3.weight", "layer3.15.bn3.weight", "layer3.15.bn3.bias", "layer3.15.bn3.running_mean", "layer3.15.bn3.running_var", "layer3.16.conv1.weight", "layer3.16.bn1.weight", "layer3.16.bn1.bias", "layer3.16.bn1.running_mean", "layer3.16.bn1.running_var", "layer3.16.conv2.weight", "layer3.16.bn2.weight", "layer3.16.bn2.bias", "layer3.16.bn2.running_mean", "layer3.16.bn2.running_var", "layer3.16.conv3.weight", "layer3.16.bn3.weight", "layer3.16.bn3.bias", "layer3.16.bn3.running_mean", "layer3.16.bn3.running_var", "layer3.17.conv1.weight", "layer3.17.bn1.weight", "layer3.17.bn1.bias", "layer3.17.bn1.running_mean", "layer3.17.bn1.running_var", "layer3.17.conv2.weight", "layer3.17.bn2.weight", "layer3.17.bn2.bias", "layer3.17.bn2.running_mean", "layer3.17.bn2.running_var", "layer3.17.conv3.weight", "layer3.17.bn3.weight", "layer3.17.bn3.bias", "layer3.17.bn3.running_mean", "layer3.17.bn3.running_var", "layer3.18.conv1.weight", "layer3.18.bn1.weight", "layer3.18.bn1.bias", "layer3.18.bn1.running_mean", "layer3.18.bn1.running_var", "layer3.18.conv2.weight", "layer3.18.bn2.weight", "layer3.18.bn2.bias", "layer3.18.bn2.running_mean", "layer3.18.bn2.running_var", "layer3.18.conv3.weight", "layer3.18.bn3.weight", "layer3.18.bn3.bias", "layer3.18.bn3.running_mean", "layer3.18.bn3.running_var", "layer3.19.conv1.weight", "layer3.19.bn1.weight", "layer3.19.bn1.bias", "layer3.19.bn1.running_mean", "layer3.19.bn1.running_var", "layer3.19.conv2.weight", "layer3.19.bn2.weight", "layer3.19.bn2.bias", "layer3.19.bn2.running_mean", "layer3.19.bn2.running_var", "layer3.19.conv3.weight", "layer3.19.bn3.weight", "layer3.19.bn3.bias", "layer3.19.bn3.running_mean", "layer3.19.bn3.running_var", "layer3.20.conv1.weight", "layer3.20.bn1.weight", "layer3.20.bn1.bias", "layer3.20.bn1.running_mean", "layer3.20.bn1.running_var", "layer3.20.conv2.weight", "layer3.20.bn2.weight", "layer3.20.bn2.bias", "layer3.20.bn2.running_mean", "layer3.20.bn2.running_var", "layer3.20.conv3.weight", "layer3.20.bn3.weight", "layer3.20.bn3.bias", "layer3.20.bn3.running_mean", "layer3.20.bn3.running_var", "layer3.21.conv1.weight", "layer3.21.bn1.weight", "layer3.21.bn1.bias", "layer3.21.bn1.running_mean", "layer3.21.bn1.running_var", "layer3.21.conv2.weight", "layer3.21.bn2.weight", "layer3.21.bn2.bias", "layer3.21.bn2.running_mean", "layer3.21.bn2.running_var", "layer3.21.conv3.weight", "layer3.21.bn3.weight", "layer3.21.bn3.bias", "layer3.21.bn3.running_mean", "layer3.21.bn3.running_var", "layer3.22.conv1.weight", "layer3.22.bn1.weight", "layer3.22.bn1.bias", "layer3.22.bn1.running_mean", "layer3.22.bn1.running_var", "layer3.22.conv2.weight", "layer3.22.bn2.weight", "layer3.22.bn2.bias", "layer3.22.bn2.running_mean", "layer3.22.bn2.running_var", "layer3.22.conv3.weight", "layer3.22.bn3.weight", "layer3.22.bn3.bias", "layer3.22.bn3.running_mean", "layer3.22.bn3.running_var", "layer4.0.conv1.weight", "layer4.0.bn1.weight", "layer4.0.bn1.bias", "layer4.0.bn1.running_mean", "layer4.0.bn1.running_var", "layer4.0.conv2.weight", "layer4.0.bn2.weight", "layer4.0.bn2.bias", "layer4.0.bn2.running_mean", "layer4.0.bn2.running_var", "layer4.0.conv3.weight", "layer4.0.bn3.weight", "layer4.0.bn3.bias", "layer4.0.bn3.running_mean", "layer4.0.bn3.running_var", "layer4.0.downsample.0.weight", "layer4.0.downsample.1.weight", "layer4.0.downsample.1.bias", "layer4.0.downsample.1.running_mean", "layer4.0.downsample.1.running_var", "layer4.1.conv1.weight", "layer4.1.bn1.weight", "layer4.1.bn1.bias", "layer4.1.bn1.running_mean", "layer4.1.bn1.running_var", "layer4.1.conv2.weight", "layer4.1.bn2.weight", "layer4.1.bn2.bias", "layer4.1.bn2.running_mean", "layer4.1.bn2.running_var", "layer4.1.conv3.weight", "layer4.1.bn3.weight", "layer4.1.bn3.bias", "layer4.1.bn3.running_mean", "layer4.1.bn3.running_var", "layer4.2.conv1.weight", "layer4.2.bn1.weight", "layer4.2.bn1.bias", "layer4.2.bn1.running_mean", "layer4.2.bn1.running_var", "layer4.2.conv2.weight", "layer4.2.bn2.weight", "layer4.2.bn2.bias", "layer4.2.bn2.running_mean", "layer4.2.bn2.running_var", "layer4.2.conv3.weight", "layer4.2.bn3.weight", "layer4.2.bn3.bias", "layer4.2.bn3.running_mean", "layer4.2.bn3.running_var". 
	Unexpected key(s) in state_dict: "features.conv1.weight", "features.bn1.weight", "features.bn1.bias", "features.bn1.running_mean", "features.bn1.running_var", "features.bn1.num_batches_tracked", "features.layer1.0.conv1.weight", "features.layer1.0.bn1.weight", "features.layer1.0.bn1.bias", "features.layer1.0.bn1.running_mean", "features.layer1.0.bn1.running_var", "features.layer1.0.bn1.num_batches_tracked", "features.layer1.0.conv2.weight", "features.layer1.0.bn2.weight", "features.layer1.0.bn2.bias", "features.layer1.0.bn2.running_mean", "features.layer1.0.bn2.running_var", "features.layer1.0.bn2.num_batches_tracked", "features.layer1.0.conv3.weight", "features.layer1.0.bn3.weight", "features.layer1.0.bn3.bias", "features.layer1.0.bn3.running_mean", "features.layer1.0.bn3.running_var", "features.layer1.0.bn3.num_batches_tracked", "features.layer1.0.downsample.0.weight", "features.layer1.0.downsample.1.weight", "features.layer1.0.downsample.1.bias", "features.layer1.0.downsample.1.running_mean", "features.layer1.0.downsample.1.running_var", "features.layer1.0.downsample.1.num_batches_tracked", "features.layer1.1.conv1.weight", "features.layer1.1.bn1.weight", "features.layer1.1.bn1.bias", "features.layer1.1.bn1.running_mean", "features.layer1.1.bn1.running_var", "features.layer1.1.bn1.num_batches_tracked", "features.layer1.1.conv2.weight", "features.layer1.1.bn2.weight", "features.layer1.1.bn2.bias", "features.layer1.1.bn2.running_mean", "features.layer1.1.bn2.running_var", "features.layer1.1.bn2.num_batches_tracked", "features.layer1.1.conv3.weight", "features.layer1.1.bn3.weight", "features.layer1.1.bn3.bias", "features.layer1.1.bn3.running_mean", "features.layer1.1.bn3.running_var", "features.layer1.1.bn3.num_batches_tracked", "features.layer1.2.conv1.weight", "features.layer1.2.bn1.weight", "features.layer1.2.bn1.bias", "features.layer1.2.bn1.running_mean", "features.layer1.2.bn1.running_var", "features.layer1.2.bn1.num_batches_tracked", "features.layer1.2.conv2.weight", "features.layer1.2.bn2.weight", "features.layer1.2.bn2.bias", "features.layer1.2.bn2.running_mean", "features.layer1.2.bn2.running_var", "features.layer1.2.bn2.num_batches_tracked", "features.layer1.2.conv3.weight", "features.layer1.2.bn3.weight", "features.layer1.2.bn3.bias", "features.layer1.2.bn3.running_mean", "features.layer1.2.bn3.running_var", "features.layer1.2.bn3.num_batches_tracked", "features.layer2.0.conv1.weight", "features.layer2.0.bn1.weight", "features.layer2.0.bn1.bias", "features.layer2.0.bn1.running_mean", "features.layer2.0.bn1.running_var", "features.layer2.0.bn1.num_batches_tracked", "features.layer2.0.conv2.weight", "features.layer2.0.bn2.weight", "features.layer2.0.bn2.bias", "features.layer2.0.bn2.running_mean", "features.layer2.0.bn2.running_var", "features.layer2.0.bn2.num_batches_tracked", "features.layer2.0.conv3.weight", "features.layer2.0.bn3.weight", "features.layer2.0.bn3.bias", "features.layer2.0.bn3.running_mean", "features.layer2.0.bn3.running_var", "features.layer2.0.bn3.num_batches_tracked", "features.layer2.0.downsample.0.weight", "features.layer2.0.downsample.1.weight", "features.layer2.0.downsample.1.bias", "features.layer2.0.downsample.1.running_mean", "features.layer2.0.downsample.1.running_var", "features.layer2.0.downsample.1.num_batches_tracked", "features.layer2.1.conv1.weight", "features.layer2.1.bn1.weight", "features.layer2.1.bn1.bias", "features.layer2.1.bn1.running_mean", "features.layer2.1.bn1.running_var", "features.layer2.1.bn1.num_batches_tracked", "features.layer2.1.conv2.weight", "features.layer2.1.bn2.weight", "features.layer2.1.bn2.bias", "features.layer2.1.bn2.running_mean", "features.layer2.1.bn2.running_var", "features.layer2.1.bn2.num_batches_tracked", "features.layer2.1.conv3.weight", "features.layer2.1.bn3.weight", "features.layer2.1.bn3.bias", "features.layer2.1.bn3.running_mean", "features.layer2.1.bn3.running_var", "features.layer2.1.bn3.num_batches_tracked", "features.layer2.2.conv1.weight", "features.layer2.2.bn1.weight", "features.layer2.2.bn1.bias", "features.layer2.2.bn1.running_mean", "features.layer2.2.bn1.running_var", "features.layer2.2.bn1.num_batches_tracked", "features.layer2.2.conv2.weight", "features.layer2.2.bn2.weight", "features.layer2.2.bn2.bias", "features.layer2.2.bn2.running_mean", "features.layer2.2.bn2.running_var", "features.layer2.2.bn2.num_batches_tracked", "features.layer2.2.conv3.weight", "features.layer2.2.bn3.weight", "features.layer2.2.bn3.bias", "features.layer2.2.bn3.running_mean", "features.layer2.2.bn3.running_var", "features.layer2.2.bn3.num_batches_tracked", "features.layer2.3.conv1.weight", "features.layer2.3.bn1.weight", "features.layer2.3.bn1.bias", "features.layer2.3.bn1.running_mean", "features.layer2.3.bn1.running_var", "features.layer2.3.bn1.num_batches_tracked", "features.layer2.3.conv2.weight", "features.layer2.3.bn2.weight", "features.layer2.3.bn2.bias", "features.layer2.3.bn2.running_mean", "features.layer2.3.bn2.running_var", "features.layer2.3.bn2.num_batches_tracked", "features.layer2.3.conv3.weight", "features.layer2.3.bn3.weight", "features.layer2.3.bn3.bias", "features.layer2.3.bn3.running_mean", "features.layer2.3.bn3.running_var", "features.layer2.3.bn3.num_batches_tracked", "features.layer3.0.conv1.weight", "features.layer3.0.bn1.weight", "features.layer3.0.bn1.bias", "features.layer3.0.bn1.running_mean", "features.layer3.0.bn1.running_var", "features.layer3.0.bn1.num_batches_tracked", "features.layer3.0.conv2.weight", "features.layer3.0.bn2.weight", "features.layer3.0.bn2.bias", "features.layer3.0.bn2.running_mean", "features.layer3.0.bn2.running_var", "features.layer3.0.bn2.num_batches_tracked", "features.layer3.0.conv3.weight", "features.layer3.0.bn3.weight", "features.layer3.0.bn3.bias", "features.layer3.0.bn3.running_mean", "features.layer3.0.bn3.running_var", "features.layer3.0.bn3.num_batches_tracked", "features.layer3.0.downsample.0.weight", "features.layer3.0.downsample.1.weight", "features.layer3.0.downsample.1.bias", "features.layer3.0.downsample.1.running_mean", "features.layer3.0.downsample.1.running_var", "features.layer3.0.downsample.1.num_batches_tracked", "features.layer3.1.conv1.weight", "features.layer3.1.bn1.weight", "features.layer3.1.bn1.bias", "features.layer3.1.bn1.running_mean", "features.layer3.1.bn1.running_var", "features.layer3.1.bn1.num_batches_tracked", "features.layer3.1.conv2.weight", "features.layer3.1.bn2.weight", "features.layer3.1.bn2.bias", "features.layer3.1.bn2.running_mean", "features.layer3.1.bn2.running_var", "features.layer3.1.bn2.num_batches_tracked", "features.layer3.1.conv3.weight", "features.layer3.1.bn3.weight", "features.layer3.1.bn3.bias", "features.layer3.1.bn3.running_mean", "features.layer3.1.bn3.running_var", "features.layer3.1.bn3.num_batches_tracked", "features.layer3.2.conv1.weight", "features.layer3.2.bn1.weight", "features.layer3.2.bn1.bias", "features.layer3.2.bn1.running_mean", "features.layer3.2.bn1.running_var", "features.layer3.2.bn1.num_batches_tracked", "features.layer3.2.conv2.weight", "features.layer3.2.bn2.weight", "features.layer3.2.bn2.bias", "features.layer3.2.bn2.running_mean", "features.layer3.2.bn2.running_var", "features.layer3.2.bn2.num_batches_tracked", "features.layer3.2.conv3.weight", "features.layer3.2.bn3.weight", "features.layer3.2.bn3.bias", "features.layer3.2.bn3.running_mean", "features.layer3.2.bn3.running_var", "features.layer3.2.bn3.num_batches_tracked", "features.layer3.3.conv1.weight", "features.layer3.3.bn1.weight", "features.layer3.3.bn1.bias", "features.layer3.3.bn1.running_mean", "features.layer3.3.bn1.running_var", "features.layer3.3.bn1.num_batches_tracked", "features.layer3.3.conv2.weight", "features.layer3.3.bn2.weight", "features.layer3.3.bn2.bias", "features.layer3.3.bn2.running_mean", "features.layer3.3.bn2.running_var", "features.layer3.3.bn2.num_batches_tracked", "features.layer3.3.conv3.weight", "features.layer3.3.bn3.weight", "features.layer3.3.bn3.bias", "features.layer3.3.bn3.running_mean", "features.layer3.3.bn3.running_var", "features.layer3.3.bn3.num_batches_tracked", "features.layer3.4.conv1.weight", "features.layer3.4.bn1.weight", "features.layer3.4.bn1.bias", "features.layer3.4.bn1.running_mean", "features.layer3.4.bn1.running_var", "features.layer3.4.bn1.num_batches_tracked", "features.layer3.4.conv2.weight", "features.layer3.4.bn2.weight", "features.layer3.4.bn2.bias", "features.layer3.4.bn2.running_mean", "features.layer3.4.bn2.running_var", "features.layer3.4.bn2.num_batches_tracked", "features.layer3.4.conv3.weight", "features.layer3.4.bn3.weight", "features.layer3.4.bn3.bias", "features.layer3.4.bn3.running_mean", "features.layer3.4.bn3.running_var", "features.layer3.4.bn3.num_batches_tracked", "features.layer3.5.conv1.weight", "features.layer3.5.bn1.weight", "features.layer3.5.bn1.bias", "features.layer3.5.bn1.running_mean", "features.layer3.5.bn1.running_var", "features.layer3.5.bn1.num_batches_tracked", "features.layer3.5.conv2.weight", "features.layer3.5.bn2.weight", "features.layer3.5.bn2.bias", "features.layer3.5.bn2.running_mean", "features.layer3.5.bn2.running_var", "features.layer3.5.bn2.num_batches_tracked", "features.layer3.5.conv3.weight", "features.layer3.5.bn3.weight", "features.layer3.5.bn3.bias", "features.layer3.5.bn3.running_mean", "features.layer3.5.bn3.running_var", "features.layer3.5.bn3.num_batches_tracked", "features.layer4.0.conv1.weight", "features.layer4.0.bn1.weight", "features.layer4.0.bn1.bias", "features.layer4.0.bn1.running_mean", "features.layer4.0.bn1.running_var", "features.layer4.0.bn1.num_batches_tracked", "features.layer4.0.conv2.weight", "features.layer4.0.bn2.weight", "features.layer4.0.bn2.bias", "features.layer4.0.bn2.running_mean", "features.layer4.0.bn2.running_var", "features.layer4.0.bn2.num_batches_tracked", "features.layer4.0.conv3.weight", "features.layer4.0.bn3.weight", "features.layer4.0.bn3.bias", "features.layer4.0.bn3.running_mean", "features.layer4.0.bn3.running_var", "features.layer4.0.bn3.num_batches_tracked", "features.layer4.0.downsample.0.weight", "features.layer4.0.downsample.1.weight", "features.layer4.0.downsample.1.bias", "features.layer4.0.downsample.1.running_mean", "features.layer4.0.downsample.1.running_var", "features.layer4.0.downsample.1.num_batches_tracked", "features.layer4.1.conv1.weight", "features.layer4.1.bn1.weight", "features.layer4.1.bn1.bias", "features.layer4.1.bn1.running_mean", "features.layer4.1.bn1.running_var", "features.layer4.1.bn1.num_batches_tracked", "features.layer4.1.conv2.weight", "features.layer4.1.bn2.weight", "features.layer4.1.bn2.bias", "features.layer4.1.bn2.running_mean", "features.layer4.1.bn2.running_var", "features.layer4.1.bn2.num_batches_tracked", "features.layer4.1.conv3.weight", "features.layer4.1.bn3.weight", "features.layer4.1.bn3.bias", "features.layer4.1.bn3.running_mean", "features.layer4.1.bn3.running_var", "features.layer4.1.bn3.num_batches_tracked", "features.layer4.2.conv1.weight", "features.layer4.2.bn1.weight", "features.layer4.2.bn1.bias", "features.layer4.2.bn1.running_mean", "features.layer4.2.bn1.running_var", "features.layer4.2.bn1.num_batches_tracked", "features.layer4.2.conv2.weight", "features.layer4.2.bn2.weight", "features.layer4.2.bn2.bias", "features.layer4.2.bn2.running_mean", "features.layer4.2.bn2.running_var", "features.layer4.2.bn2.num_batches_tracked", "features.layer4.2.conv3.weight", "features.layer4.2.bn3.weight", "features.layer4.2.bn3.bias", "features.layer4.2.bn3.running_mean", "features.layer4.2.bn3.running_var", "features.layer4.2.bn3.num_batches_tracked", "features.fc.weight", "features.fc.bias", "frozen_features.0.weight", "frozen_features.1.weight", "frozen_features.1.bias", "frozen_features.1.running_mean", "frozen_features.1.running_var", "frozen_features.1.num_batches_tracked", "frozen_features.4.0.conv1.weight", "frozen_features.4.0.bn1.weight", "frozen_features.4.0.bn1.bias", "frozen_features.4.0.bn1.running_mean", "frozen_features.4.0.bn1.running_var", "frozen_features.4.0.bn1.num_batches_tracked", "frozen_features.4.0.conv2.weight", "frozen_features.4.0.bn2.weight", "frozen_features.4.0.bn2.bias", "frozen_features.4.0.bn2.running_mean", "frozen_features.4.0.bn2.running_var", "frozen_features.4.0.bn2.num_batches_tracked", "frozen_features.4.0.conv3.weight", "frozen_features.4.0.bn3.weight", "frozen_features.4.0.bn3.bias", "frozen_features.4.0.bn3.running_mean", "frozen_features.4.0.bn3.running_var", "frozen_features.4.0.bn3.num_batches_tracked", "frozen_features.4.0.downsample.0.weight", "frozen_features.4.0.downsample.1.weight", "frozen_features.4.0.downsample.1.bias", "frozen_features.4.0.downsample.1.running_mean", "frozen_features.4.0.downsample.1.running_var", "frozen_features.4.0.downsample.1.num_batches_tracked", "frozen_features.4.1.conv1.weight", "frozen_features.4.1.bn1.weight", "frozen_features.4.1.bn1.bias", "frozen_features.4.1.bn1.running_mean", "frozen_features.4.1.bn1.running_var", "frozen_features.4.1.bn1.num_batches_tracked", "frozen_features.4.1.conv2.weight", "frozen_features.4.1.bn2.weight", "frozen_features.4.1.bn2.bias", "frozen_features.4.1.bn2.running_mean", "frozen_features.4.1.bn2.running_var", "frozen_features.4.1.bn2.num_batches_tracked", "frozen_features.4.1.conv3.weight", "frozen_features.4.1.bn3.weight", "frozen_features.4.1.bn3.bias", "frozen_features.4.1.bn3.running_mean", "frozen_features.4.1.bn3.running_var", "frozen_features.4.1.bn3.num_batches_tracked", "frozen_features.4.2.conv1.weight", "frozen_features.4.2.bn1.weight", "frozen_features.4.2.bn1.bias", "frozen_features.4.2.bn1.running_mean", "frozen_features.4.2.bn1.running_var", "frozen_features.4.2.bn1.num_batches_tracked", "frozen_features.4.2.conv2.weight", "frozen_features.4.2.bn2.weight", "frozen_features.4.2.bn2.bias", "frozen_features.4.2.bn2.running_mean", "frozen_features.4.2.bn2.running_var", "frozen_features.4.2.bn2.num_batches_tracked", "frozen_features.4.2.conv3.weight", "frozen_features.4.2.bn3.weight", "frozen_features.4.2.bn3.bias", "frozen_features.4.2.bn3.running_mean", "frozen_features.4.2.bn3.running_var", "frozen_features.4.2.bn3.num_batches_tracked", "frozen_features.5.0.conv1.weight", "frozen_features.5.0.bn1.weight", "frozen_features.5.0.bn1.bias", "frozen_features.5.0.bn1.running_mean", "frozen_features.5.0.bn1.running_var", "frozen_features.5.0.bn1.num_batches_tracked", "frozen_features.5.0.conv2.weight", "frozen_features.5.0.bn2.weight", "frozen_features.5.0.bn2.bias", "frozen_features.5.0.bn2.running_mean", "frozen_features.5.0.bn2.running_var", "frozen_features.5.0.bn2.num_batches_tracked", "frozen_features.5.0.conv3.weight", "frozen_features.5.0.bn3.weight", "frozen_features.5.0.bn3.bias", "frozen_features.5.0.bn3.running_mean", "frozen_features.5.0.bn3.running_var", "frozen_features.5.0.bn3.num_batches_tracked", "frozen_features.5.0.downsample.0.weight", "frozen_features.5.0.downsample.1.weight", "frozen_features.5.0.downsample.1.bias", "frozen_features.5.0.downsample.1.running_mean", "frozen_features.5.0.downsample.1.running_var", "frozen_features.5.0.downsample.1.num_batches_tracked", "frozen_features.5.1.conv1.weight", "frozen_features.5.1.bn1.weight", "frozen_features.5.1.bn1.bias", "frozen_features.5.1.bn1.running_mean", "frozen_features.5.1.bn1.running_var", "frozen_features.5.1.bn1.num_batches_tracked", "frozen_features.5.1.conv2.weight", "frozen_features.5.1.bn2.weight", "frozen_features.5.1.bn2.bias", "frozen_features.5.1.bn2.running_mean", "frozen_features.5.1.bn2.running_var", "frozen_features.5.1.bn2.num_batches_tracked", "frozen_features.5.1.conv3.weight", "frozen_features.5.1.bn3.weight", "frozen_features.5.1.bn3.bias", "frozen_features.5.1.bn3.running_mean", "frozen_features.5.1.bn3.running_var", "frozen_features.5.1.bn3.num_batches_tracked", "frozen_features.5.2.conv1.weight", "frozen_features.5.2.bn1.weight", "frozen_features.5.2.bn1.bias", "frozen_features.5.2.bn1.running_mean", "frozen_features.5.2.bn1.running_var", "frozen_features.5.2.bn1.num_batches_tracked", "frozen_features.5.2.conv2.weight", "frozen_features.5.2.bn2.weight", "frozen_features.5.2.bn2.bias", "frozen_features.5.2.bn2.running_mean", "frozen_features.5.2.bn2.running_var", "frozen_features.5.2.bn2.num_batches_tracked", "frozen_features.5.2.conv3.weight", "frozen_features.5.2.bn3.weight", "frozen_features.5.2.bn3.bias", "frozen_features.5.2.bn3.running_mean", "frozen_features.5.2.bn3.running_var", "frozen_features.5.2.bn3.num_batches_tracked", "frozen_features.5.3.conv1.weight", "frozen_features.5.3.bn1.weight", "frozen_features.5.3.bn1.bias", "frozen_features.5.3.bn1.running_mean", "frozen_features.5.3.bn1.running_var", "frozen_features.5.3.bn1.num_batches_tracked", "frozen_features.5.3.conv2.weight", "frozen_features.5.3.bn2.weight", "frozen_features.5.3.bn2.bias", "frozen_features.5.3.bn2.running_mean", "frozen_features.5.3.bn2.running_var", "frozen_features.5.3.bn2.num_batches_tracked", "frozen_features.5.3.conv3.weight", "frozen_features.5.3.bn3.weight", "frozen_features.5.3.bn3.bias", "frozen_features.5.3.bn3.running_mean", "frozen_features.5.3.bn3.running_var", "frozen_features.5.3.bn3.num_batches_tracked", "frozen_features.6.0.conv1.weight", "frozen_features.6.0.bn1.weight", "frozen_features.6.0.bn1.bias", "frozen_features.6.0.bn1.running_mean", "frozen_features.6.0.bn1.running_var", "frozen_features.6.0.bn1.num_batches_tracked", "frozen_features.6.0.conv2.weight", "frozen_features.6.0.bn2.weight", "frozen_features.6.0.bn2.bias", "frozen_features.6.0.bn2.running_mean", "frozen_features.6.0.bn2.running_var", "frozen_features.6.0.bn2.num_batches_tracked", "frozen_features.6.0.conv3.weight", "frozen_features.6.0.bn3.weight", "frozen_features.6.0.bn3.bias", "frozen_features.6.0.bn3.running_mean", "frozen_features.6.0.bn3.running_var", "frozen_features.6.0.bn3.num_batches_tracked", "frozen_features.6.0.downsample.0.weight", "frozen_features.6.0.downsample.1.weight", "frozen_features.6.0.downsample.1.bias", "frozen_features.6.0.downsample.1.running_mean", "frozen_features.6.0.downsample.1.running_var", "frozen_features.6.0.downsample.1.num_batches_tracked", "frozen_features.6.1.conv1.weight", "frozen_features.6.1.bn1.weight", "frozen_features.6.1.bn1.bias", "frozen_features.6.1.bn1.running_mean", "frozen_features.6.1.bn1.running_var", "frozen_features.6.1.bn1.num_batches_tracked", "frozen_features.6.1.conv2.weight", "frozen_features.6.1.bn2.weight", "frozen_features.6.1.bn2.bias", "frozen_features.6.1.bn2.running_mean", "frozen_features.6.1.bn2.running_var", "frozen_features.6.1.bn2.num_batches_tracked", "frozen_features.6.1.conv3.weight", "frozen_features.6.1.bn3.weight", "frozen_features.6.1.bn3.bias", "frozen_features.6.1.bn3.running_mean", "frozen_features.6.1.bn3.running_var", "frozen_features.6.1.bn3.num_batches_tracked", "frozen_features.6.2.conv1.weight", "frozen_features.6.2.bn1.weight", "frozen_features.6.2.bn1.bias", "frozen_features.6.2.bn1.running_mean", "frozen_features.6.2.bn1.running_var", "frozen_features.6.2.bn1.num_batches_tracked", "frozen_features.6.2.conv2.weight", "frozen_features.6.2.bn2.weight", "frozen_features.6.2.bn2.bias", "frozen_features.6.2.bn2.running_mean", "frozen_features.6.2.bn2.running_var", "frozen_features.6.2.bn2.num_batches_tracked", "frozen_features.6.2.conv3.weight", "frozen_features.6.2.bn3.weight", "frozen_features.6.2.bn3.bias", "frozen_features.6.2.bn3.running_mean", "frozen_features.6.2.bn3.running_var", "frozen_features.6.2.bn3.num_batches_tracked", "frozen_features.6.3.conv1.weight", "frozen_features.6.3.bn1.weight", "frozen_features.6.3.bn1.bias", "frozen_features.6.3.bn1.running_mean", "frozen_features.6.3.bn1.running_var", "frozen_features.6.3.bn1.num_batches_tracked", "frozen_features.6.3.conv2.weight", "frozen_features.6.3.bn2.weight", "frozen_features.6.3.bn2.bias", "frozen_features.6.3.bn2.running_mean", "frozen_features.6.3.bn2.running_var", "frozen_features.6.3.bn2.num_batches_tracked", "frozen_features.6.3.conv3.weight", "frozen_features.6.3.bn3.weight", "frozen_features.6.3.bn3.bias", "frozen_features.6.3.bn3.running_mean", "frozen_features.6.3.bn3.running_var", "frozen_features.6.3.bn3.num_batches_tracked", "frozen_features.6.4.conv1.weight", "frozen_features.6.4.bn1.weight", "frozen_features.6.4.bn1.bias", "frozen_features.6.4.bn1.running_mean", "frozen_features.6.4.bn1.running_var", "frozen_features.6.4.bn1.num_batches_tracked", "frozen_features.6.4.conv2.weight", "frozen_features.6.4.bn2.weight", "frozen_features.6.4.bn2.bias", "frozen_features.6.4.bn2.running_mean", "frozen_features.6.4.bn2.running_var", "frozen_features.6.4.bn2.num_batches_tracked", "frozen_features.6.4.conv3.weight", "frozen_features.6.4.bn3.weight", "frozen_features.6.4.bn3.bias", "frozen_features.6.4.bn3.running_mean", "frozen_features.6.4.bn3.running_var", "frozen_features.6.4.bn3.num_batches_tracked", "frozen_features.6.5.conv1.weight", "frozen_features.6.5.bn1.weight", "frozen_features.6.5.bn1.bias", "frozen_features.6.5.bn1.running_mean", "frozen_features.6.5.bn1.running_var", "frozen_features.6.5.bn1.num_batches_tracked", "frozen_features.6.5.conv2.weight", "frozen_features.6.5.bn2.weight", "frozen_features.6.5.bn2.bias", "frozen_features.6.5.bn2.running_mean", "frozen_features.6.5.bn2.running_var", "frozen_features.6.5.bn2.num_batches_tracked", "frozen_features.6.5.conv3.weight", "frozen_features.6.5.bn3.weight", "frozen_features.6.5.bn3.bias", "frozen_features.6.5.bn3.running_mean", "frozen_features.6.5.bn3.running_var", "frozen_features.6.5.bn3.num_batches_tracked", "unfrozen_features.0.0.conv1.weight", "unfrozen_features.0.0.bn1.weight", "unfrozen_features.0.0.bn1.bias", "unfrozen_features.0.0.bn1.running_mean", "unfrozen_features.0.0.bn1.running_var", "unfrozen_features.0.0.bn1.num_batches_tracked", "unfrozen_features.0.0.conv2.weight", "unfrozen_features.0.0.bn2.weight", "unfrozen_features.0.0.bn2.bias", "unfrozen_features.0.0.bn2.running_mean", "unfrozen_features.0.0.bn2.running_var", "unfrozen_features.0.0.bn2.num_batches_tracked", "unfrozen_features.0.0.conv3.weight", "unfrozen_features.0.0.bn3.weight", "unfrozen_features.0.0.bn3.bias", "unfrozen_features.0.0.bn3.running_mean", "unfrozen_features.0.0.bn3.running_var", "unfrozen_features.0.0.bn3.num_batches_tracked", "unfrozen_features.0.0.downsample.0.weight", "unfrozen_features.0.0.downsample.1.weight", "unfrozen_features.0.0.downsample.1.bias", "unfrozen_features.0.0.downsample.1.running_mean", "unfrozen_features.0.0.downsample.1.running_var", "unfrozen_features.0.0.downsample.1.num_batches_tracked", "unfrozen_features.0.1.conv1.weight", "unfrozen_features.0.1.bn1.weight", "unfrozen_features.0.1.bn1.bias", "unfrozen_features.0.1.bn1.running_mean", "unfrozen_features.0.1.bn1.running_var", "unfrozen_features.0.1.bn1.num_batches_tracked", "unfrozen_features.0.1.conv2.weight", "unfrozen_features.0.1.bn2.weight", "unfrozen_features.0.1.bn2.bias", "unfrozen_features.0.1.bn2.running_mean", "unfrozen_features.0.1.bn2.running_var", "unfrozen_features.0.1.bn2.num_batches_tracked", "unfrozen_features.0.1.conv3.weight", "unfrozen_features.0.1.bn3.weight", "unfrozen_features.0.1.bn3.bias", "unfrozen_features.0.1.bn3.running_mean", "unfrozen_features.0.1.bn3.running_var", "unfrozen_features.0.1.bn3.num_batches_tracked", "unfrozen_features.0.2.conv1.weight", "unfrozen_features.0.2.bn1.weight", "unfrozen_features.0.2.bn1.bias", "unfrozen_features.0.2.bn1.running_mean", "unfrozen_features.0.2.bn1.running_var", "unfrozen_features.0.2.bn1.num_batches_tracked", "unfrozen_features.0.2.conv2.weight", "unfrozen_features.0.2.bn2.weight", "unfrozen_features.0.2.bn2.bias", "unfrozen_features.0.2.bn2.running_mean", "unfrozen_features.0.2.bn2.running_var", "unfrozen_features.0.2.bn2.num_batches_tracked", "unfrozen_features.0.2.conv3.weight", "unfrozen_features.0.2.bn3.weight", "unfrozen_features.0.2.bn3.bias", "unfrozen_features.0.2.bn3.running_mean", "unfrozen_features.0.2.bn3.running_var", "unfrozen_features.0.2.bn3.num_batches_tracked". 

# Test on hospital 21, 

In [None]:
# Resnet 101 on single hospital
# Define your transformations
test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Create an instance of CustomDataset for testing
test_dataset = CustomDataset('/local/data1/honzh073/local_repository/FL/code/6_model_test/hospital45.csv', transform=test_transform)


# For test dataset
test_NFF_count = sum(1 for _, label in test_dataset if label == 0)  # 0 NFF
test_AFF_count = sum(1 for _, label in test_dataset if label == 1)  # 1 AFF
print(f"AFF: {test_AFF_count}, ratio: {test_AFF_count / (test_AFF_count + test_NFF_count):.2f}")
print(f"NFF: {test_NFF_count}, ratio: {test_NFF_count / (test_AFF_count + test_NFF_count):.2f}")

test_model(model, test_dataset=test_dataset, batch_size=16)


In [None]:
# test dataset
test_dataset = CustomDataset('/local/data1/honzh073/local_repository/FL/code/6_model_test/hospital55.csv', transform=test_transform)

# For test dataset
test_NFF_count = sum(1 for _, label in test_dataset if label == 0)  # 0 NFF
test_AFF_count = sum(1 for _, label in test_dataset if label == 1)  # 1 AFF
print(f"AFF: {test_AFF_count}, ratio: {test_AFF_count / (test_AFF_count + test_NFF_count):.2f}")
print(f"NFF: {test_NFF_count}, ratio: {test_NFF_count / (test_AFF_count + test_NFF_count):.2f}")

test_model(model, test_dataset=test_dataset, batch_size=16)
