In [13]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [14]:
# Install required packages
!pip install wandb torch torchvision pandas numpy matplotlib seaborn

# Set up Kaggle API
!pip install kaggle

Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch)
  Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cufft-cu12==11.2.1.3 (from torch)
  Downloading nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-curand-cu12==10.3.5.147 (from torch)
  Downloading nvidia_curand_cu12-10.3.5

In [15]:
# Upload your kaggle.json to Colab and run:
!mkdir -p ~/.kaggle
!cp /content/drive/MyDrive/ColabNotebooks/kaggle_API_credentials/kaggle.json ~/.kaggle/kaggle.json
! chmod 600 ~/.kaggle/kaggle.json

In [16]:
# Download the dataset
!kaggle competitions download -c challenges-in-representation-learning-facial-expression-recognition-challenge
!unzip -q challenges-in-representation-learning-facial-expression-recognition-challenge.zip

Downloading challenges-in-representation-learning-facial-expression-recognition-challenge.zip to /content
 91% 259M/285M [00:00<00:00, 469MB/s]
100% 285M/285M [00:00<00:00, 476MB/s]


In [17]:
import wandb
import torch
import os
import numpy as np
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

# Define device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Login to wandb
wandb.login()

# Init run
run = wandb.init(project="facial-expression-recognition")

# Download artifacts
cnn_artifact = run.use_artifact('ellekvirikashvili-free-university-of-tbilisi-/facial-expression-recognition/run-30aso492-history:v0')
cnn_dir = cnn_artifact.download()

vit_artifact = run.use_artifact('ellekvirikashvili-free-university-of-tbilisi-/facial-expression-recognition/vit-fer2013-final-model:v0')
vit_dir = vit_artifact.download()




[34m[1mwandb[0m:   1 of 1 files downloaded.  
[34m[1mwandb[0m:   1 of 1 files downloaded.  


In [18]:
class TestDataset(Dataset):
    def __init__(self, dataframe, transform=None):
        self.data = dataframe
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        pixels = self.data.iloc[idx]['pixels']
        image = np.array(pixels.split(), dtype='uint8')
        image = image.reshape(48, 48, 1).astype('float32') / 255.0
        image = np.repeat(image, 3, axis=-1)

        if self.transform:
            image = self.transform(image)

        return image


In [19]:
test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((224, 224)),  # for ViT — you can skip or change for CNN
])

test_dataset = TestDataset(test_df, transform=test_transform)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)


In [None]:
# Load your models (adjust based on how you saved them)
cnn_model = torch.load(f"{cnn_dir}/model.pth")  # or however you saved it
vit_model = torch.load(f"{vit_dir}/model.pth")

# Set to evaluation mode
cnn_model.eval()
vit_model.eval()

In [None]:
def ensemble_predict(cnn_model, vit_model, x, weights=[0.5, 0.5]):
    """
    Ensemble prediction combining CNN and ViT models
    """
    with torch.no_grad():
        # Get predictions from both models
        cnn_pred = torch.softmax(cnn_model(x), dim=1)
        vit_pred = torch.softmax(vit_model(x), dim=1)

        # Weighted average
        ensemble_pred = weights[0] * cnn_pred + weights[1] * vit_pred

    return ensemble_pred, cnn_pred, vit_pred

In [None]:
from sklearn.metrics import accuracy_score
import numpy as np

def find_optimal_weights(cnn_preds, vit_preds, true_labels):
    """Find optimal weights for ensemble"""
    best_acc = 0
    best_weights = [0.5, 0.5]

    for w1 in np.arange(0.1, 1.0, 0.1):
        w2 = 1 - w1
        ensemble_pred = w1 * cnn_preds + w2 * vit_preds
        pred_labels = torch.argmax(ensemble_pred, dim=1)
        acc = accuracy_score(true_labels, pred_labels)

        if acc > best_acc:
            best_acc = acc
            best_weights = [w1, w2]

    return best_weights, best_acc

In [1]:
import wandb
import torch
import torch.nn.functional as F
import numpy as np
from sklearn.metrics import accuracy_score, classification_report
from torch.utils.data import DataLoader

class FERensemble:
    def __init__(self, project_name="facial-expression-recognition"):
        self.project_name = project_name
        self.cnn_model = None
        self.vit_model = None
        self.weights = [0.5, 0.5]  # Default equal weights

    def download_models(self):
        """Download both models from W&B"""
        print("Downloading models from W&B...")

        # Initialize wandb
        run = wandb.init(project=self.project_name)

        # Download CNN model
        print("Downloading CNN model...")
        cnn_artifact = run.use_artifact(
            'ellekvirikashvili-free-university-of-tbilisi-/facial-expression-recognition/run-30aso492-history:v0'
        )
        cnn_dir = cnn_artifact.download()

        # Download ViT model
        print("Downloading ViT model...")
        vit_artifact = run.use_artifact(
            'ellekvirikashvili-free-university-of-tbilisi-/facial-expression-recognition/vit-fer2013-final-model:v0',
            type='model'
        )
        vit_dir = vit_artifact.download()

        return cnn_dir, vit_dir

    def load_models(self, cnn_path, vit_path, device='cuda' if torch.cuda.is_available() else 'cpu'):
        """Load the downloaded models"""
        print(f"Loading models on {device}...")

        # Load CNN model
        self.cnn_model = torch.load(cnn_path, map_location=device)
        self.cnn_model.eval()

        # Load ViT model
        self.vit_model = torch.load(vit_path, map_location=device)
        self.vit_model.eval()

        print("Models loaded successfully!")

    def predict_single(self, x, return_individual=False):
        """Make prediction on a single batch"""
        with torch.no_grad():
            # Get predictions from both models
            cnn_logits = self.cnn_model(x)
            vit_logits = self.vit_model(x)

            # Apply softmax to get probabilities
            cnn_probs = F.softmax(cnn_logits, dim=1)
            vit_probs = F.softmax(vit_logits, dim=1)

            # Ensemble prediction (weighted average)
            ensemble_probs = self.weights[0] * cnn_probs + self.weights[1] * vit_probs

            if return_individual:
                return ensemble_probs, cnn_probs, vit_probs
            return ensemble_probs

    def evaluate_on_dataset(self, dataloader, device='cuda' if torch.cuda.is_available() else 'cpu'):
        """Evaluate ensemble on a dataset"""
        all_ensemble_preds = []
        all_cnn_preds = []
        all_vit_preds = []
        all_labels = []

        print("Evaluating on dataset...")

        for batch_idx, (data, labels) in enumerate(dataloader):
            data, labels = data.to(device), labels.to(device)

            # Get predictions
            ensemble_probs, cnn_probs, vit_probs = self.predict_single(data, return_individual=True)

            # Store predictions
            all_ensemble_preds.append(ensemble_probs.cpu())
            all_cnn_preds.append(cnn_probs.cpu())
            all_vit_preds.append(vit_probs.cpu())
            all_labels.append(labels.cpu())

            if batch_idx % 50 == 0:
                print(f"Processed {batch_idx} batches...")

        # Concatenate all predictions
        all_ensemble_preds = torch.cat(all_ensemble_preds, dim=0)
        all_cnn_preds = torch.cat(all_cnn_preds, dim=0)
        all_vit_preds = torch.cat(all_vit_preds, dim=0)
        all_labels = torch.cat(all_labels, dim=0)

        return all_ensemble_preds, all_cnn_preds, all_vit_preds, all_labels

    def optimize_weights(self, val_dataloader, device='cuda' if torch.cuda.is_available() else 'cpu'):
        """Find optimal ensemble weights using validation set"""
        print("Finding optimal ensemble weights...")

        # Get predictions on validation set
        _, cnn_preds, vit_preds, true_labels = self.evaluate_on_dataset(val_dataloader, device)

        best_acc = 0
        best_weights = [0.5, 0.5]

        # Grid search for optimal weights
        for w1 in np.arange(0.1, 1.0, 0.1):
            w2 = 1 - w1

            # Calculate ensemble predictions with these weights
            ensemble_pred = w1 * cnn_preds + w2 * vit_preds
            pred_labels = torch.argmax(ensemble_pred, dim=1)

            # Calculate accuracy
            acc = accuracy_score(true_labels.numpy(), pred_labels.numpy())

            if acc > best_acc:
                best_acc = acc
                best_weights = [w1, w2]

        self.weights = best_weights
        print(f"Optimal weights found: CNN={best_weights[0]:.2f}, ViT={best_weights[1]:.2f}")
        print(f"Best validation accuracy: {best_acc:.4f}")

        return best_weights, best_acc

    def get_metrics(self, predictions, true_labels, class_names=None):
        """Calculate comprehensive metrics"""
        pred_labels = torch.argmax(predictions, dim=1)

        # Basic accuracy
        accuracy = accuracy_score(true_labels.numpy(), pred_labels.numpy())

        # Classification report
        if class_names is None:
            class_names = ['Angry', 'Disgust', 'Fear', 'Happy', 'Sad', 'Surprise', 'Neutral']

        report = classification_report(
            true_labels.numpy(),
            pred_labels.numpy(),
            target_names=class_names,
            output_dict=True
        )

        return accuracy, report

    def compare_models(self, test_dataloader, device='cuda' if torch.cuda.is_available() else 'cpu'):
        """Compare individual models vs ensemble"""
        print("Comparing model performances...")

        ensemble_preds, cnn_preds, vit_preds, true_labels = self.evaluate_on_dataset(test_dataloader, device)

        # Calculate accuracies
        ensemble_acc, ensemble_report = self.get_metrics(ensemble_preds, true_labels)
        cnn_acc, cnn_report = self.get_metrics(cnn_preds, true_labels)
        vit_acc, vit_report = self.get_metrics(vit_preds, true_labels)

        print("\n" + "="*50)
        print("MODEL COMPARISON RESULTS")
        print("="*50)
        print(f"CNN Accuracy:      {cnn_acc:.4f}")
        print(f"ViT Accuracy:      {vit_acc:.4f}")
        print(f"Ensemble Accuracy: {ensemble_acc:.4f}")
        print(f"Improvement:       {ensemble_acc - max(cnn_acc, vit_acc):.4f}")
        print("="*50)

        # Log to W&B
        wandb.log({
            "cnn_accuracy": cnn_acc,
            "vit_accuracy": vit_acc,
            "ensemble_accuracy": ensemble_acc,
            "ensemble_improvement": ensemble_acc - max(cnn_acc, vit_acc),
            "optimal_cnn_weight": self.weights[0],
            "optimal_vit_weight": self.weights[1]
        })

        return {
            'ensemble': {'accuracy': ensemble_acc, 'report': ensemble_report},
            'cnn': {'accuracy': cnn_acc, 'report': cnn_report},
            'vit': {'accuracy': vit_acc, 'report': vit_report}
        }

# Example usage
def main():
    # Initialize ensemble
    ensemble = FERensemble()

    # Download models
    cnn_dir, vit_dir = ensemble.download_models()

    # Load models (adjust paths based on your saved model structure)
    cnn_model_path = f"{cnn_dir}/model.pth"  # or wherever your .pth file is
    vit_model_path = f"{vit_dir}/model.pth"

    ensemble.load_models(cnn_model_path, vit_model_path)

    # Assuming you have your dataloaders ready
    # val_dataloader = your_validation_dataloader
    # test_dataloader = your_test_dataloader

    # Optimize ensemble weights on validation set
    # ensemble.optimize_weights(val_dataloader)

    # Compare models on test set
    # results = ensemble.compare_models(test_dataloader)

    print("Ensemble setup complete!")

if __name__ == "__main__":
    main()

Downloading models from W&B...


<IPython.core.display.Javascript object>

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize?ref=models
wandb: Paste an API key from your profile and hit enter:

 ··········


[34m[1mwandb[0m: No netrc file found, creating one.
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33mellekvirikashvili[0m ([33mellekvirikashvili-free-university-of-tbilisi-[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


Downloading CNN model...


[34m[1mwandb[0m:   1 of 1 files downloaded.  


Downloading ViT model...


[34m[1mwandb[0m:   1 of 1 files downloaded.  


Loading models on cpu...


FileNotFoundError: [Errno 2] No such file or directory: '/content/artifacts/run-30aso492-history:v0/model.pth'

In [2]:
import wandb
run = wandb.init()
artifact = run.use_artifact('ellekvirikashvili-free-university-of-tbilisi-/facial-expression-recognition/vit-fer2013-final-model:v0', type='model')
artifact_dir = artifact.download()

[34m[1mwandb[0m: Currently logged in as: [33mellekvirikashvili[0m ([33mellekvirikashvili-free-university-of-tbilisi-[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


[34m[1mwandb[0m:   1 of 1 files downloaded.  


In [3]:
import torch
from timm import create_model

# Load the full checkpoint dictionary
checkpoint = torch.load(f"{artifact_dir}/final_vit_model.pth")

# Recreate the architecture
model = create_model('mobilevit_xxs', pretrained=False, num_classes=7)

# Load the weights from the 'model_state_dict' key
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()

print("Checkpoint loaded successfully ✅")


RuntimeError: Error(s) in loading state_dict for ByobNet:
	Missing key(s) in state_dict: "stem.conv.weight", "stem.bn.weight", "stem.bn.bias", "stem.bn.running_mean", "stem.bn.running_var", "stages.0.0.conv1_1x1.conv.weight", "stages.0.0.conv1_1x1.bn.weight", "stages.0.0.conv1_1x1.bn.bias", "stages.0.0.conv1_1x1.bn.running_mean", "stages.0.0.conv1_1x1.bn.running_var", "stages.0.0.conv2_kxk.conv.weight", "stages.0.0.conv2_kxk.bn.weight", "stages.0.0.conv2_kxk.bn.bias", "stages.0.0.conv2_kxk.bn.running_mean", "stages.0.0.conv2_kxk.bn.running_var", "stages.0.0.conv3_1x1.conv.weight", "stages.0.0.conv3_1x1.bn.weight", "stages.0.0.conv3_1x1.bn.bias", "stages.0.0.conv3_1x1.bn.running_mean", "stages.0.0.conv3_1x1.bn.running_var", "stages.1.0.conv1_1x1.conv.weight", "stages.1.0.conv1_1x1.bn.weight", "stages.1.0.conv1_1x1.bn.bias", "stages.1.0.conv1_1x1.bn.running_mean", "stages.1.0.conv1_1x1.bn.running_var", "stages.1.0.conv2_kxk.conv.weight", "stages.1.0.conv2_kxk.bn.weight", "stages.1.0.conv2_kxk.bn.bias", "stages.1.0.conv2_kxk.bn.running_mean", "stages.1.0.conv2_kxk.bn.running_var", "stages.1.0.conv3_1x1.conv.weight", "stages.1.0.conv3_1x1.bn.weight", "stages.1.0.conv3_1x1.bn.bias", "stages.1.0.conv3_1x1.bn.running_mean", "stages.1.0.conv3_1x1.bn.running_var", "stages.1.1.conv1_1x1.conv.weight", "stages.1.1.conv1_1x1.bn.weight", "stages.1.1.conv1_1x1.bn.bias", "stages.1.1.conv1_1x1.bn.running_mean", "stages.1.1.conv1_1x1.bn.running_var", "stages.1.1.conv2_kxk.conv.weight", "stages.1.1.conv2_kxk.bn.weight", "stages.1.1.conv2_kxk.bn.bias", "stages.1.1.conv2_kxk.bn.running_mean", "stages.1.1.conv2_kxk.bn.running_var", "stages.1.1.conv3_1x1.conv.weight", "stages.1.1.conv3_1x1.bn.weight", "stages.1.1.conv3_1x1.bn.bias", "stages.1.1.conv3_1x1.bn.running_mean", "stages.1.1.conv3_1x1.bn.running_var", "stages.1.2.conv1_1x1.conv.weight", "stages.1.2.conv1_1x1.bn.weight", "stages.1.2.conv1_1x1.bn.bias", "stages.1.2.conv1_1x1.bn.running_mean", "stages.1.2.conv1_1x1.bn.running_var", "stages.1.2.conv2_kxk.conv.weight", "stages.1.2.conv2_kxk.bn.weight", "stages.1.2.conv2_kxk.bn.bias", "stages.1.2.conv2_kxk.bn.running_mean", "stages.1.2.conv2_kxk.bn.running_var", "stages.1.2.conv3_1x1.conv.weight", "stages.1.2.conv3_1x1.bn.weight", "stages.1.2.conv3_1x1.bn.bias", "stages.1.2.conv3_1x1.bn.running_mean", "stages.1.2.conv3_1x1.bn.running_var", "stages.2.0.conv1_1x1.conv.weight", "stages.2.0.conv1_1x1.bn.weight", "stages.2.0.conv1_1x1.bn.bias", "stages.2.0.conv1_1x1.bn.running_mean", "stages.2.0.conv1_1x1.bn.running_var", "stages.2.0.conv2_kxk.conv.weight", "stages.2.0.conv2_kxk.bn.weight", "stages.2.0.conv2_kxk.bn.bias", "stages.2.0.conv2_kxk.bn.running_mean", "stages.2.0.conv2_kxk.bn.running_var", "stages.2.0.conv3_1x1.conv.weight", "stages.2.0.conv3_1x1.bn.weight", "stages.2.0.conv3_1x1.bn.bias", "stages.2.0.conv3_1x1.bn.running_mean", "stages.2.0.conv3_1x1.bn.running_var", "stages.2.1.conv_kxk.conv.weight", "stages.2.1.conv_kxk.bn.weight", "stages.2.1.conv_kxk.bn.bias", "stages.2.1.conv_kxk.bn.running_mean", "stages.2.1.conv_kxk.bn.running_var", "stages.2.1.conv_1x1.weight", "stages.2.1.transformer.0.norm1.weight", "stages.2.1.transformer.0.norm1.bias", "stages.2.1.transformer.0.attn.qkv.weight", "stages.2.1.transformer.0.attn.qkv.bias", "stages.2.1.transformer.0.attn.proj.weight", "stages.2.1.transformer.0.attn.proj.bias", "stages.2.1.transformer.0.norm2.weight", "stages.2.1.transformer.0.norm2.bias", "stages.2.1.transformer.0.mlp.fc1.weight", "stages.2.1.transformer.0.mlp.fc1.bias", "stages.2.1.transformer.0.mlp.fc2.weight", "stages.2.1.transformer.0.mlp.fc2.bias", "stages.2.1.transformer.1.norm1.weight", "stages.2.1.transformer.1.norm1.bias", "stages.2.1.transformer.1.attn.qkv.weight", "stages.2.1.transformer.1.attn.qkv.bias", "stages.2.1.transformer.1.attn.proj.weight", "stages.2.1.transformer.1.attn.proj.bias", "stages.2.1.transformer.1.norm2.weight", "stages.2.1.transformer.1.norm2.bias", "stages.2.1.transformer.1.mlp.fc1.weight", "stages.2.1.transformer.1.mlp.fc1.bias", "stages.2.1.transformer.1.mlp.fc2.weight", "stages.2.1.transformer.1.mlp.fc2.bias", "stages.2.1.norm.weight", "stages.2.1.norm.bias", "stages.2.1.conv_proj.conv.weight", "stages.2.1.conv_proj.bn.weight", "stages.2.1.conv_proj.bn.bias", "stages.2.1.conv_proj.bn.running_mean", "stages.2.1.conv_proj.bn.running_var", "stages.2.1.conv_fusion.conv.weight", "stages.2.1.conv_fusion.bn.weight", "stages.2.1.conv_fusion.bn.bias", "stages.2.1.conv_fusion.bn.running_mean", "stages.2.1.conv_fusion.bn.running_var", "stages.3.0.conv1_1x1.conv.weight", "stages.3.0.conv1_1x1.bn.weight", "stages.3.0.conv1_1x1.bn.bias", "stages.3.0.conv1_1x1.bn.running_mean", "stages.3.0.conv1_1x1.bn.running_var", "stages.3.0.conv2_kxk.conv.weight", "stages.3.0.conv2_kxk.bn.weight", "stages.3.0.conv2_kxk.bn.bias", "stages.3.0.conv2_kxk.bn.running_mean", "stages.3.0.conv2_kxk.bn.running_var", "stages.3.0.conv3_1x1.conv.weight", "stages.3.0.conv3_1x1.bn.weight", "stages.3.0.conv3_1x1.bn.bias", "stages.3.0.conv3_1x1.bn.running_mean", "stages.3.0.conv3_1x1.bn.running_var", "stages.3.1.conv_kxk.conv.weight", "stages.3.1.conv_kxk.bn.weight", "stages.3.1.conv_kxk.bn.bias", "stages.3.1.conv_kxk.bn.running_mean", "stages.3.1.conv_kxk.bn.running_var", "stages.3.1.conv_1x1.weight", "stages.3.1.transformer.0.norm1.weight", "stages.3.1.transformer.0.norm1.bias", "stages.3.1.transformer.0.attn.qkv.weight", "stages.3.1.transformer.0.attn.qkv.bias", "stages.3.1.transformer.0.attn.proj.weight", "stages.3.1.transformer.0.attn.proj.bias", "stages.3.1.transformer.0.norm2.weight", "stages.3.1.transformer.0.norm2.bias", "stages.3.1.transformer.0.mlp.fc1.weight", "stages.3.1.transformer.0.mlp.fc1.bias", "stages.3.1.transformer.0.mlp.fc2.weight", "stages.3.1.transformer.0.mlp.fc2.bias", "stages.3.1.transformer.1.norm1.weight", "stages.3.1.transformer.1.norm1.bias", "stages.3.1.transformer.1.attn.qkv.weight", "stages.3.1.transformer.1.attn.qkv.bias", "stages.3.1.transformer.1.attn.proj.weight", "stages.3.1.transformer.1.attn.proj.bias", "stages.3.1.transformer.1.norm2.weight", "stages.3.1.transformer.1.norm2.bias", "stages.3.1.transformer.1.mlp.fc1.weight", "stages.3.1.transformer.1.mlp.fc1.bias", "stages.3.1.transformer.1.mlp.fc2.weight", "stages.3.1.transformer.1.mlp.fc2.bias", "stages.3.1.transformer.2.norm1.weight", "stages.3.1.transformer.2.norm1.bias", "stages.3.1.transformer.2.attn.qkv.weight", "stages.3.1.transformer.2.attn.qkv.bias", "stages.3.1.transformer.2.attn.proj.weight", "stages.3.1.transformer.2.attn.proj.bias", "stages.3.1.transformer.2.norm2.weight", "stages.3.1.transformer.2.norm2.bias", "stages.3.1.transformer.2.mlp.fc1.weight", "stages.3.1.transformer.2.mlp.fc1.bias", "stages.3.1.transformer.2.mlp.fc2.weight", "stages.3.1.transformer.2.mlp.fc2.bias", "stages.3.1.transformer.3.norm1.weight", "stages.3.1.transformer.3.norm1.bias", "stages.3.1.transformer.3.attn.qkv.weight", "stages.3.1.transformer.3.attn.qkv.bias", "stages.3.1.transformer.3.attn.proj.weight", "stages.3.1.transformer.3.attn.proj.bias", "stages.3.1.transformer.3.norm2.weight", "stages.3.1.transformer.3.norm2.bias", "stages.3.1.transformer.3.mlp.fc1.weight", "stages.3.1.transformer.3.mlp.fc1.bias", "stages.3.1.transformer.3.mlp.fc2.weight", "stages.3.1.transformer.3.mlp.fc2.bias", "stages.3.1.norm.weight", "stages.3.1.norm.bias", "stages.3.1.conv_proj.conv.weight", "stages.3.1.conv_proj.bn.weight", "stages.3.1.conv_proj.bn.bias", "stages.3.1.conv_proj.bn.running_mean", "stages.3.1.conv_proj.bn.running_var", "stages.3.1.conv_fusion.conv.weight", "stages.3.1.conv_fusion.bn.weight", "stages.3.1.conv_fusion.bn.bias", "stages.3.1.conv_fusion.bn.running_mean", "stages.3.1.conv_fusion.bn.running_var", "stages.4.0.conv1_1x1.conv.weight", "stages.4.0.conv1_1x1.bn.weight", "stages.4.0.conv1_1x1.bn.bias", "stages.4.0.conv1_1x1.bn.running_mean", "stages.4.0.conv1_1x1.bn.running_var", "stages.4.0.conv2_kxk.conv.weight", "stages.4.0.conv2_kxk.bn.weight", "stages.4.0.conv2_kxk.bn.bias", "stages.4.0.conv2_kxk.bn.running_mean", "stages.4.0.conv2_kxk.bn.running_var", "stages.4.0.conv3_1x1.conv.weight", "stages.4.0.conv3_1x1.bn.weight", "stages.4.0.conv3_1x1.bn.bias", "stages.4.0.conv3_1x1.bn.running_mean", "stages.4.0.conv3_1x1.bn.running_var", "stages.4.1.conv_kxk.conv.weight", "stages.4.1.conv_kxk.bn.weight", "stages.4.1.conv_kxk.bn.bias", "stages.4.1.conv_kxk.bn.running_mean", "stages.4.1.conv_kxk.bn.running_var", "stages.4.1.conv_1x1.weight", "stages.4.1.transformer.0.norm1.weight", "stages.4.1.transformer.0.norm1.bias", "stages.4.1.transformer.0.attn.qkv.weight", "stages.4.1.transformer.0.attn.qkv.bias", "stages.4.1.transformer.0.attn.proj.weight", "stages.4.1.transformer.0.attn.proj.bias", "stages.4.1.transformer.0.norm2.weight", "stages.4.1.transformer.0.norm2.bias", "stages.4.1.transformer.0.mlp.fc1.weight", "stages.4.1.transformer.0.mlp.fc1.bias", "stages.4.1.transformer.0.mlp.fc2.weight", "stages.4.1.transformer.0.mlp.fc2.bias", "stages.4.1.transformer.1.norm1.weight", "stages.4.1.transformer.1.norm1.bias", "stages.4.1.transformer.1.attn.qkv.weight", "stages.4.1.transformer.1.attn.qkv.bias", "stages.4.1.transformer.1.attn.proj.weight", "stages.4.1.transformer.1.attn.proj.bias", "stages.4.1.transformer.1.norm2.weight", "stages.4.1.transformer.1.norm2.bias", "stages.4.1.transformer.1.mlp.fc1.weight", "stages.4.1.transformer.1.mlp.fc1.bias", "stages.4.1.transformer.1.mlp.fc2.weight", "stages.4.1.transformer.1.mlp.fc2.bias", "stages.4.1.transformer.2.norm1.weight", "stages.4.1.transformer.2.norm1.bias", "stages.4.1.transformer.2.attn.qkv.weight", "stages.4.1.transformer.2.attn.qkv.bias", "stages.4.1.transformer.2.attn.proj.weight", "stages.4.1.transformer.2.attn.proj.bias", "stages.4.1.transformer.2.norm2.weight", "stages.4.1.transformer.2.norm2.bias", "stages.4.1.transformer.2.mlp.fc1.weight", "stages.4.1.transformer.2.mlp.fc1.bias", "stages.4.1.transformer.2.mlp.fc2.weight", "stages.4.1.transformer.2.mlp.fc2.bias", "stages.4.1.norm.weight", "stages.4.1.norm.bias", "stages.4.1.conv_proj.conv.weight", "stages.4.1.conv_proj.bn.weight", "stages.4.1.conv_proj.bn.bias", "stages.4.1.conv_proj.bn.running_mean", "stages.4.1.conv_proj.bn.running_var", "stages.4.1.conv_fusion.conv.weight", "stages.4.1.conv_fusion.bn.weight", "stages.4.1.conv_fusion.bn.bias", "stages.4.1.conv_fusion.bn.running_mean", "stages.4.1.conv_fusion.bn.running_var", "final_conv.conv.weight", "final_conv.bn.weight", "final_conv.bn.bias", "final_conv.bn.running_mean", "final_conv.bn.running_var", "head.fc.weight", "head.fc.bias". 
	Unexpected key(s) in state_dict: "model.stem.conv.weight", "model.stem.bn.weight", "model.stem.bn.bias", "model.stem.bn.running_mean", "model.stem.bn.running_var", "model.stem.bn.num_batches_tracked", "model.stages.0.0.conv1_1x1.conv.weight", "model.stages.0.0.conv1_1x1.bn.weight", "model.stages.0.0.conv1_1x1.bn.bias", "model.stages.0.0.conv1_1x1.bn.running_mean", "model.stages.0.0.conv1_1x1.bn.running_var", "model.stages.0.0.conv1_1x1.bn.num_batches_tracked", "model.stages.0.0.conv2_kxk.conv.weight", "model.stages.0.0.conv2_kxk.bn.weight", "model.stages.0.0.conv2_kxk.bn.bias", "model.stages.0.0.conv2_kxk.bn.running_mean", "model.stages.0.0.conv2_kxk.bn.running_var", "model.stages.0.0.conv2_kxk.bn.num_batches_tracked", "model.stages.0.0.conv3_1x1.conv.weight", "model.stages.0.0.conv3_1x1.bn.weight", "model.stages.0.0.conv3_1x1.bn.bias", "model.stages.0.0.conv3_1x1.bn.running_mean", "model.stages.0.0.conv3_1x1.bn.running_var", "model.stages.0.0.conv3_1x1.bn.num_batches_tracked", "model.stages.1.0.conv1_1x1.conv.weight", "model.stages.1.0.conv1_1x1.bn.weight", "model.stages.1.0.conv1_1x1.bn.bias", "model.stages.1.0.conv1_1x1.bn.running_mean", "model.stages.1.0.conv1_1x1.bn.running_var", "model.stages.1.0.conv1_1x1.bn.num_batches_tracked", "model.stages.1.0.conv2_kxk.conv.weight", "model.stages.1.0.conv2_kxk.bn.weight", "model.stages.1.0.conv2_kxk.bn.bias", "model.stages.1.0.conv2_kxk.bn.running_mean", "model.stages.1.0.conv2_kxk.bn.running_var", "model.stages.1.0.conv2_kxk.bn.num_batches_tracked", "model.stages.1.0.conv3_1x1.conv.weight", "model.stages.1.0.conv3_1x1.bn.weight", "model.stages.1.0.conv3_1x1.bn.bias", "model.stages.1.0.conv3_1x1.bn.running_mean", "model.stages.1.0.conv3_1x1.bn.running_var", "model.stages.1.0.conv3_1x1.bn.num_batches_tracked", "model.stages.1.1.conv1_1x1.conv.weight", "model.stages.1.1.conv1_1x1.bn.weight", "model.stages.1.1.conv1_1x1.bn.bias", "model.stages.1.1.conv1_1x1.bn.running_mean", "model.stages.1.1.conv1_1x1.bn.running_var", "model.stages.1.1.conv1_1x1.bn.num_batches_tracked", "model.stages.1.1.conv2_kxk.conv.weight", "model.stages.1.1.conv2_kxk.bn.weight", "model.stages.1.1.conv2_kxk.bn.bias", "model.stages.1.1.conv2_kxk.bn.running_mean", "model.stages.1.1.conv2_kxk.bn.running_var", "model.stages.1.1.conv2_kxk.bn.num_batches_tracked", "model.stages.1.1.conv3_1x1.conv.weight", "model.stages.1.1.conv3_1x1.bn.weight", "model.stages.1.1.conv3_1x1.bn.bias", "model.stages.1.1.conv3_1x1.bn.running_mean", "model.stages.1.1.conv3_1x1.bn.running_var", "model.stages.1.1.conv3_1x1.bn.num_batches_tracked", "model.stages.1.2.conv1_1x1.conv.weight", "model.stages.1.2.conv1_1x1.bn.weight", "model.stages.1.2.conv1_1x1.bn.bias", "model.stages.1.2.conv1_1x1.bn.running_mean", "model.stages.1.2.conv1_1x1.bn.running_var", "model.stages.1.2.conv1_1x1.bn.num_batches_tracked", "model.stages.1.2.conv2_kxk.conv.weight", "model.stages.1.2.conv2_kxk.bn.weight", "model.stages.1.2.conv2_kxk.bn.bias", "model.stages.1.2.conv2_kxk.bn.running_mean", "model.stages.1.2.conv2_kxk.bn.running_var", "model.stages.1.2.conv2_kxk.bn.num_batches_tracked", "model.stages.1.2.conv3_1x1.conv.weight", "model.stages.1.2.conv3_1x1.bn.weight", "model.stages.1.2.conv3_1x1.bn.bias", "model.stages.1.2.conv3_1x1.bn.running_mean", "model.stages.1.2.conv3_1x1.bn.running_var", "model.stages.1.2.conv3_1x1.bn.num_batches_tracked", "model.stages.2.0.conv1_1x1.conv.weight", "model.stages.2.0.conv1_1x1.bn.weight", "model.stages.2.0.conv1_1x1.bn.bias", "model.stages.2.0.conv1_1x1.bn.running_mean", "model.stages.2.0.conv1_1x1.bn.running_var", "model.stages.2.0.conv1_1x1.bn.num_batches_tracked", "model.stages.2.0.conv2_kxk.conv.weight", "model.stages.2.0.conv2_kxk.bn.weight", "model.stages.2.0.conv2_kxk.bn.bias", "model.stages.2.0.conv2_kxk.bn.running_mean", "model.stages.2.0.conv2_kxk.bn.running_var", "model.stages.2.0.conv2_kxk.bn.num_batches_tracked", "model.stages.2.0.conv3_1x1.conv.weight", "model.stages.2.0.conv3_1x1.bn.weight", "model.stages.2.0.conv3_1x1.bn.bias", "model.stages.2.0.conv3_1x1.bn.running_mean", "model.stages.2.0.conv3_1x1.bn.running_var", "model.stages.2.0.conv3_1x1.bn.num_batches_tracked", "model.stages.2.1.conv_kxk.conv.weight", "model.stages.2.1.conv_kxk.bn.weight", "model.stages.2.1.conv_kxk.bn.bias", "model.stages.2.1.conv_kxk.bn.running_mean", "model.stages.2.1.conv_kxk.bn.running_var", "model.stages.2.1.conv_kxk.bn.num_batches_tracked", "model.stages.2.1.conv_1x1.weight", "model.stages.2.1.transformer.0.norm1.weight", "model.stages.2.1.transformer.0.norm1.bias", "model.stages.2.1.transformer.0.attn.qkv.weight", "model.stages.2.1.transformer.0.attn.qkv.bias", "model.stages.2.1.transformer.0.attn.proj.weight", "model.stages.2.1.transformer.0.attn.proj.bias", "model.stages.2.1.transformer.0.norm2.weight", "model.stages.2.1.transformer.0.norm2.bias", "model.stages.2.1.transformer.0.mlp.fc1.weight", "model.stages.2.1.transformer.0.mlp.fc1.bias", "model.stages.2.1.transformer.0.mlp.fc2.weight", "model.stages.2.1.transformer.0.mlp.fc2.bias", "model.stages.2.1.transformer.1.norm1.weight", "model.stages.2.1.transformer.1.norm1.bias", "model.stages.2.1.transformer.1.attn.qkv.weight", "model.stages.2.1.transformer.1.attn.qkv.bias", "model.stages.2.1.transformer.1.attn.proj.weight", "model.stages.2.1.transformer.1.attn.proj.bias", "model.stages.2.1.transformer.1.norm2.weight", "model.stages.2.1.transformer.1.norm2.bias", "model.stages.2.1.transformer.1.mlp.fc1.weight", "model.stages.2.1.transformer.1.mlp.fc1.bias", "model.stages.2.1.transformer.1.mlp.fc2.weight", "model.stages.2.1.transformer.1.mlp.fc2.bias", "model.stages.2.1.norm.weight", "model.stages.2.1.norm.bias", "model.stages.2.1.conv_proj.conv.weight", "model.stages.2.1.conv_proj.bn.weight", "model.stages.2.1.conv_proj.bn.bias", "model.stages.2.1.conv_proj.bn.running_mean", "model.stages.2.1.conv_proj.bn.running_var", "model.stages.2.1.conv_proj.bn.num_batches_tracked", "model.stages.2.1.conv_fusion.conv.weight", "model.stages.2.1.conv_fusion.bn.weight", "model.stages.2.1.conv_fusion.bn.bias", "model.stages.2.1.conv_fusion.bn.running_mean", "model.stages.2.1.conv_fusion.bn.running_var", "model.stages.2.1.conv_fusion.bn.num_batches_tracked", "model.stages.3.0.conv1_1x1.conv.weight", "model.stages.3.0.conv1_1x1.bn.weight", "model.stages.3.0.conv1_1x1.bn.bias", "model.stages.3.0.conv1_1x1.bn.running_mean", "model.stages.3.0.conv1_1x1.bn.running_var", "model.stages.3.0.conv1_1x1.bn.num_batches_tracked", "model.stages.3.0.conv2_kxk.conv.weight", "model.stages.3.0.conv2_kxk.bn.weight", "model.stages.3.0.conv2_kxk.bn.bias", "model.stages.3.0.conv2_kxk.bn.running_mean", "model.stages.3.0.conv2_kxk.bn.running_var", "model.stages.3.0.conv2_kxk.bn.num_batches_tracked", "model.stages.3.0.conv3_1x1.conv.weight", "model.stages.3.0.conv3_1x1.bn.weight", "model.stages.3.0.conv3_1x1.bn.bias", "model.stages.3.0.conv3_1x1.bn.running_mean", "model.stages.3.0.conv3_1x1.bn.running_var", "model.stages.3.0.conv3_1x1.bn.num_batches_tracked", "model.stages.3.1.conv_kxk.conv.weight", "model.stages.3.1.conv_kxk.bn.weight", "model.stages.3.1.conv_kxk.bn.bias", "model.stages.3.1.conv_kxk.bn.running_mean", "model.stages.3.1.conv_kxk.bn.running_var", "model.stages.3.1.conv_kxk.bn.num_batches_tracked", "model.stages.3.1.conv_1x1.weight", "model.stages.3.1.transformer.0.norm1.weight", "model.stages.3.1.transformer.0.norm1.bias", "model.stages.3.1.transformer.0.attn.qkv.weight", "model.stages.3.1.transformer.0.attn.qkv.bias", "model.stages.3.1.transformer.0.attn.proj.weight", "model.stages.3.1.transformer.0.attn.proj.bias", "model.stages.3.1.transformer.0.norm2.weight", "model.stages.3.1.transformer.0.norm2.bias", "model.stages.3.1.transformer.0.mlp.fc1.weight", "model.stages.3.1.transformer.0.mlp.fc1.bias", "model.stages.3.1.transformer.0.mlp.fc2.weight", "model.stages.3.1.transformer.0.mlp.fc2.bias", "model.stages.3.1.transformer.1.norm1.weight", "model.stages.3.1.transformer.1.norm1.bias", "model.stages.3.1.transformer.1.attn.qkv.weight", "model.stages.3.1.transformer.1.attn.qkv.bias", "model.stages.3.1.transformer.1.attn.proj.weight", "model.stages.3.1.transformer.1.attn.proj.bias", "model.stages.3.1.transformer.1.norm2.weight", "model.stages.3.1.transformer.1.norm2.bias", "model.stages.3.1.transformer.1.mlp.fc1.weight", "model.stages.3.1.transformer.1.mlp.fc1.bias", "model.stages.3.1.transformer.1.mlp.fc2.weight", "model.stages.3.1.transformer.1.mlp.fc2.bias", "model.stages.3.1.transformer.2.norm1.weight", "model.stages.3.1.transformer.2.norm1.bias", "model.stages.3.1.transformer.2.attn.qkv.weight", "model.stages.3.1.transformer.2.attn.qkv.bias", "model.stages.3.1.transformer.2.attn.proj.weight", "model.stages.3.1.transformer.2.attn.proj.bias", "model.stages.3.1.transformer.2.norm2.weight", "model.stages.3.1.transformer.2.norm2.bias", "model.stages.3.1.transformer.2.mlp.fc1.weight", "model.stages.3.1.transformer.2.mlp.fc1.bias", "model.stages.3.1.transformer.2.mlp.fc2.weight", "model.stages.3.1.transformer.2.mlp.fc2.bias", "model.stages.3.1.transformer.3.norm1.weight", "model.stages.3.1.transformer.3.norm1.bias", "model.stages.3.1.transformer.3.attn.qkv.weight", "model.stages.3.1.transformer.3.attn.qkv.bias", "model.stages.3.1.transformer.3.attn.proj.weight", "model.stages.3.1.transformer.3.attn.proj.bias", "model.stages.3.1.transformer.3.norm2.weight", "model.stages.3.1.transformer.3.norm2.bias", "model.stages.3.1.transformer.3.mlp.fc1.weight", "model.stages.3.1.transformer.3.mlp.fc1.bias", "model.stages.3.1.transformer.3.mlp.fc2.weight", "model.stages.3.1.transformer.3.mlp.fc2.bias", "model.stages.3.1.norm.weight", "model.stages.3.1.norm.bias", "model.stages.3.1.conv_proj.conv.weight", "model.stages.3.1.conv_proj.bn.weight", "model.stages.3.1.conv_proj.bn.bias", "model.stages.3.1.conv_proj.bn.running_mean", "model.stages.3.1.conv_proj.bn.running_var", "model.stages.3.1.conv_proj.bn.num_batches_tracked", "model.stages.3.1.conv_fusion.conv.weight", "model.stages.3.1.conv_fusion.bn.weight", "model.stages.3.1.conv_fusion.bn.bias", "model.stages.3.1.conv_fusion.bn.running_mean", "model.stages.3.1.conv_fusion.bn.running_var", "model.stages.3.1.conv_fusion.bn.num_batches_tracked", "model.stages.4.0.conv1_1x1.conv.weight", "model.stages.4.0.conv1_1x1.bn.weight", "model.stages.4.0.conv1_1x1.bn.bias", "model.stages.4.0.conv1_1x1.bn.running_mean", "model.stages.4.0.conv1_1x1.bn.running_var", "model.stages.4.0.conv1_1x1.bn.num_batches_tracked", "model.stages.4.0.conv2_kxk.conv.weight", "model.stages.4.0.conv2_kxk.bn.weight", "model.stages.4.0.conv2_kxk.bn.bias", "model.stages.4.0.conv2_kxk.bn.running_mean", "model.stages.4.0.conv2_kxk.bn.running_var", "model.stages.4.0.conv2_kxk.bn.num_batches_tracked", "model.stages.4.0.conv3_1x1.conv.weight", "model.stages.4.0.conv3_1x1.bn.weight", "model.stages.4.0.conv3_1x1.bn.bias", "model.stages.4.0.conv3_1x1.bn.running_mean", "model.stages.4.0.conv3_1x1.bn.running_var", "model.stages.4.0.conv3_1x1.bn.num_batches_tracked", "model.stages.4.1.conv_kxk.conv.weight", "model.stages.4.1.conv_kxk.bn.weight", "model.stages.4.1.conv_kxk.bn.bias", "model.stages.4.1.conv_kxk.bn.running_mean", "model.stages.4.1.conv_kxk.bn.running_var", "model.stages.4.1.conv_kxk.bn.num_batches_tracked", "model.stages.4.1.conv_1x1.weight", "model.stages.4.1.transformer.0.norm1.weight", "model.stages.4.1.transformer.0.norm1.bias", "model.stages.4.1.transformer.0.attn.qkv.weight", "model.stages.4.1.transformer.0.attn.qkv.bias", "model.stages.4.1.transformer.0.attn.proj.weight", "model.stages.4.1.transformer.0.attn.proj.bias", "model.stages.4.1.transformer.0.norm2.weight", "model.stages.4.1.transformer.0.norm2.bias", "model.stages.4.1.transformer.0.mlp.fc1.weight", "model.stages.4.1.transformer.0.mlp.fc1.bias", "model.stages.4.1.transformer.0.mlp.fc2.weight", "model.stages.4.1.transformer.0.mlp.fc2.bias", "model.stages.4.1.transformer.1.norm1.weight", "model.stages.4.1.transformer.1.norm1.bias", "model.stages.4.1.transformer.1.attn.qkv.weight", "model.stages.4.1.transformer.1.attn.qkv.bias", "model.stages.4.1.transformer.1.attn.proj.weight", "model.stages.4.1.transformer.1.attn.proj.bias", "model.stages.4.1.transformer.1.norm2.weight", "model.stages.4.1.transformer.1.norm2.bias", "model.stages.4.1.transformer.1.mlp.fc1.weight", "model.stages.4.1.transformer.1.mlp.fc1.bias", "model.stages.4.1.transformer.1.mlp.fc2.weight", "model.stages.4.1.transformer.1.mlp.fc2.bias", "model.stages.4.1.transformer.2.norm1.weight", "model.stages.4.1.transformer.2.norm1.bias", "model.stages.4.1.transformer.2.attn.qkv.weight", "model.stages.4.1.transformer.2.attn.qkv.bias", "model.stages.4.1.transformer.2.attn.proj.weight", "model.stages.4.1.transformer.2.attn.proj.bias", "model.stages.4.1.transformer.2.norm2.weight", "model.stages.4.1.transformer.2.norm2.bias", "model.stages.4.1.transformer.2.mlp.fc1.weight", "model.stages.4.1.transformer.2.mlp.fc1.bias", "model.stages.4.1.transformer.2.mlp.fc2.weight", "model.stages.4.1.transformer.2.mlp.fc2.bias", "model.stages.4.1.norm.weight", "model.stages.4.1.norm.bias", "model.stages.4.1.conv_proj.conv.weight", "model.stages.4.1.conv_proj.bn.weight", "model.stages.4.1.conv_proj.bn.bias", "model.stages.4.1.conv_proj.bn.running_mean", "model.stages.4.1.conv_proj.bn.running_var", "model.stages.4.1.conv_proj.bn.num_batches_tracked", "model.stages.4.1.conv_fusion.conv.weight", "model.stages.4.1.conv_fusion.bn.weight", "model.stages.4.1.conv_fusion.bn.bias", "model.stages.4.1.conv_fusion.bn.running_mean", "model.stages.4.1.conv_fusion.bn.running_var", "model.stages.4.1.conv_fusion.bn.num_batches_tracked", "model.final_conv.conv.weight", "model.final_conv.bn.weight", "model.final_conv.bn.bias", "model.final_conv.bn.running_mean", "model.final_conv.bn.running_var", "model.final_conv.bn.num_batches_tracked", "model.head.fc.weight", "model.head.fc.bias". 

In [7]:
print(model)


ByobNet(
  (stem): ConvNormAct(
    (conv): Conv2d(3, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
    (bn): BatchNormAct2d(
      16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True
      (drop): Identity()
      (act): SiLU(inplace=True)
    )
  )
  (stages): Sequential(
    (0): Sequential(
      (0): BottleneckBlock(
        (shortcut): Identity()
        (conv1_1x1): ConvNormAct(
          (conv): Conv2d(16, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn): BatchNormAct2d(
            32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True
            (drop): Identity()
            (act): SiLU(inplace=True)
          )
        )
        (conv2_kxk): ConvNormAct(
          (conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
          (bn): BatchNormAct2d(
            32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True
            (drop): Identity()
          

In [4]:
checkpoint = torch.load(f"{artifact_dir}/final_vit_model.pth")
print(checkpoint.keys())


dict_keys(['model_state_dict', 'optimizer_state_dict', 'test_accuracy', 'config'])


In [5]:
state_dict = checkpoint['model_state_dict']


In [6]:
state_dict = checkpoint['model_state_dict']
for key in list(state_dict.keys())[:10]:  # just first 10 keys to peek
    print(key)


model.stem.conv.weight
model.stem.bn.weight
model.stem.bn.bias
model.stem.bn.running_mean
model.stem.bn.running_var
model.stem.bn.num_batches_tracked
model.stages.0.0.conv1_1x1.conv.weight
model.stages.0.0.conv1_1x1.bn.weight
model.stages.0.0.conv1_1x1.bn.bias
model.stages.0.0.conv1_1x1.bn.running_mean


In [7]:
from collections import OrderedDict

# Fix the keys
fixed_state_dict = OrderedDict()
for k, v in state_dict.items():
    new_k = k.replace("model.", "", 1)  # Remove only the first 'model.'
    fixed_state_dict[new_k] = v

# Load it
model.load_state_dict(fixed_state_dict)


<All keys matched successfully>

In [8]:
import pandas as pd

test_df = pd.read_csv('test.csv')

In [13]:
from torch.utils.data import Dataset
from torch.utils.data import DataLoader

class TestDataset(Dataset):
    def __init__(self, dataframe, transform=None):
        self.data = dataframe
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        pixels = self.data.iloc[idx]['pixels']
        image = np.array(pixels.split(), dtype='uint8')
        image = image.reshape(48, 48, 1).astype('float32') / 255.0
        image = np.repeat(image, 3, axis=-1)  # Convert to 3 channels for ViT

        if self.transform:
            image = self.transform(image)

        return image  # no label


In [10]:

# Set device

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")


Using device: cuda


In [14]:
from torchvision import transforms
from PIL import Image


transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])

In [15]:
test_dataset = TestDataset(test_df, transform=transform)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

In [16]:
import numpy as np

model.to(device)
model.eval()
predictions = []

with torch.no_grad():
    for images in test_loader:
        images = images.to(device)
        outputs = model(images)
        _, preds = torch.max(outputs, 1)
        predictions.extend(preds.cpu().numpy())


In [23]:
submission_df = test_df.copy()
submission_df['emotion'] = predictions  # or vit_preds or your own blend
submission_df = submission_df[['emotion']]  # make sure the format matches what's required

submission_df.to_csv("submission.csv", index=False)
