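# **torchxai.landscape**

This notebook demonstrates torchxai on an ImageNet-pretrained ResNet-18 and a 100-image CIFAR-10 subset. It extracts per-layer features with `tx.utils.extract_features`, computes a loss landscape with `tx.landscape.loss_landscape`, and plots it with `tx.landscape.visualize`.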

In [1]:
import torchxai as tx

import torch
from torchvision import models, datasets, transforms
import numpy as np

from tqdm import tqdm
import matplotlib.pyplot as plt
import seaborn as sns

## **datasets**

In [2]:
# The ResNet-18 below is loaded with ImageNet-pretrained weights, and torchvision's
# pretrained models expect inputs normalized with the ImageNet channel statistics.
# ToTensor() alone only scales pixels to [0, 1], which feeds the network
# out-of-distribution values.
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)
N_SAMPLES = 100  # size of the CIFAR-10 subset used for this demo; raise for a fuller run

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])
full_dataset = datasets.CIFAR10(root='../data', train=True, download=True, transform=transform)

# take the first N_SAMPLES training images (deterministic, no shuffling) to keep the demo fast
train_indices = list(range(N_SAMPLES))
dataset = torch.utils.data.Subset(full_dataset, train_indices)

dataloader = torch.utils.data.DataLoader(dataset, batch_size=4, shuffle=False, num_workers=0)

Files already downloaded and verified


## **models**

In [3]:
# Use a GPU when one is available; the loss-landscape sweep further down is slow on CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# `pretrained=True` is deprecated since torchvision 0.13 in favour of the `weights` enum.
# Use the enum when present and fall back on older torchvision; both load the same
# ImageNet-1k weights for resnet18.
if hasattr(models, "ResNet18_Weights"):
    model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
else:
    model = models.resnet18(pretrained=True)

# eval(): BatchNorm layers use their stored running statistics instead of per-batch
# statistics, and forward passes in later cells no longer overwrite those statistics.
model = model.to(device).eval()



## **torchxai.utils.extract_features**

In [4]:
# Run torchxai's feature extractor over the CIFAR-10 subset. The displayed keys (see output)
# are the model's submodule names ('conv1', 'layer1.0.bn1', ...), with '' for the root module.
# Values are presumably each module's outputs over the dataloader — TODO confirm shape /
# aggregation in torchxai. exclude_layers=None presumably means "keep every layer" — confirm.
features = tx.utils.extract_features(model, dataloader, device, exclude_layers=None)
features.keys()

dict_keys(['', 'conv1', 'bn1', 'relu', 'maxpool', 'layer1', 'layer1.0', 'layer1.0.conv1', 'layer1.0.bn1', 'layer1.0.relu', 'layer1.0.conv2', 'layer1.0.bn2', 'layer1.1', 'layer1.1.conv1', 'layer1.1.bn1', 'layer1.1.relu', 'layer1.1.conv2', 'layer1.1.bn2', 'layer2', 'layer2.0', 'layer2.0.conv1', 'layer2.0.bn1', 'layer2.0.relu', 'layer2.0.conv2', 'layer2.0.bn2', 'layer2.0.downsample', 'layer2.0.downsample.0', 'layer2.0.downsample.1', 'layer2.1', 'layer2.1.conv1', 'layer2.1.bn1', 'layer2.1.relu', 'layer2.1.conv2', 'layer2.1.bn2', 'layer3', 'layer3.0', 'layer3.0.conv1', 'layer3.0.bn1', 'layer3.0.relu', 'layer3.0.conv2', 'layer3.0.bn2', 'layer3.0.downsample', 'layer3.0.downsample.0', 'layer3.0.downsample.1', 'layer3.1', 'layer3.1.conv1', 'layer3.1.bn1', 'layer3.1.relu', 'layer3.1.conv2', 'layer3.1.bn2', 'layer4', 'layer4.0', 'layer4.0.conv1', 'layer4.0.bn1', 'layer4.0.relu', 'layer4.0.conv2', 'layer4.0.bn2', 'layer4.0.downsample', 'layer4.0.downsample.0', 'layer4.0.downsample.1', 'layer4.1', 

## **Landscape**

In [5]:
# Seed the RNGs so that any random sampling inside loss_landscape (presumably random
# perturbation directions — TODO confirm against torchxai) is reproducible on Restart & Run All.
torch.manual_seed(0)
np.random.seed(0)

# NOTE(review): the ImageNet-pretrained head produces 1000 logits while CIFAR-10 labels are
# 0-9, so this loss scores CIFAR labels as ImageNet class indices. Treat the surface as an
# illustration of the API rather than a measure of CIFAR-10 performance.
RESOLUTION = 5  # grid points per axis; the progress bar shows RESOLUTION**2 evaluations (25 at 5)
criterion = torch.nn.CrossEntropyLoss()
x, y, loss_surface = tx.landscape.loss_landscape(
    model, dataloader, device, criterion=criterion, resolution=RESOLUTION, verbose=0
)

100%|██████████| 25/25 [00:41<00:00,  1.67s/it]


In [7]:
# Plot the loss values computed by loss_landscape against the (x, y) grid it returned.
tx.landscape.visualize(x, y, loss_surface)