In [3]:
############
# Imports #
############

import torch
import os
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import rasterio
import numpy as np
from utils import ImageDataset, TestSet, filter_dataset, imshow_transform
from cnn_classifier import model_4D
from torch.autograd import Variable
from skimage.transform import resize
from skimage.io import imshow
import wandb
import matplotlib.pyplot as plt 
import torch.optim.lr_scheduler as lr_scheduler
import torchmetrics
import json
from pseudomask import Pseudomasks

In [None]:
def model_from_artifact(run_id, artifact):
    """Rebuild a ``model_4D`` from weights stored in a Weights & Biases artifact.

    Resumes the existing wandb run ``run_id`` in project ``VGG_CAMs``,
    downloads version ``artifact`` of the ``model`` artifact, and loads the
    ``model.pth`` state dict it contains into a freshly constructed
    ``model_4D``.

    Args:
        run_id: id of the wandb run to resume (``resume='must'`` makes wandb
            raise if the run does not exist).
        artifact: artifact version/alias appended after ``model:``,
            e.g. ``'v31'``.

    Returns:
        The ``model_4D`` instance, moved to CUDA when available (CPU
        otherwise) and switched to eval mode.

    Note:
        The resumed wandb run is left active after this function returns.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    run = wandb.init(project= 'VGG_CAMs', id= run_id, resume = 'must')
    # Keep the artifact object under its own name instead of rebinding the
    # `artifact` argument (which is the version string).
    model_artifact = run.use_artifact(f'nadjaflechner/VGG_CAMs/model:{artifact}', type='model')
    artifact_dir = model_artifact.download()
    # map_location remaps tensors saved from a GPU run onto the device chosen
    # above; without it, loading a CUDA checkpoint on a CPU-only machine fails.
    state_dict = torch.load(os.path.join(artifact_dir, "model.pth"), map_location=device)
    model = model_4D()
    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()
    return model

# Use the GPU when torch can see one; otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
# Load the weights stored as version 'v31' of the model artifact, resuming run '3493djxg'.
model = model_from_artifact(run_id='3493djxg', artifact='v31')

In [23]:
# Mean-squared-error criterion.
loss = nn.MSELoss()

# Start from a newly constructed model_4D; this rebinds `model`, replacing
# whatever it referred to before this cell ran.
model = model_4D()

# Replace the classifier with conv -> batch-norm -> ReLU -> conv, two channels
# in and out throughout. 3x3 kernels with padding 1 (default stride 1) keep the
# spatial size of the input.
model.classifier = nn.Sequential(
    nn.Conv2d(in_channels=2, out_channels=2, kernel_size=3, padding=1),
    nn.BatchNorm2d(num_features=2),
    nn.ReLU(inplace=True),
    nn.Conv2d(in_channels=2, out_channels=2, kernel_size=3, padding=1),
)

def percent_calc(conv_output, activation_threshold = 1.5):
    """Share of channel-1 pixels whose activation exceeds a threshold.

    Args:
        conv_output: 4-D tensor indexed as ``[sample, channel, row, col]``
            with at least two channels; only channel index 1 is inspected.
        activation_threshold: a pixel counts as activated when its value is
            strictly greater than this.

    Returns:
        1-D tensor with one entry per sample: activated pixel count divided
        by the number of pixels in the map. This is a fraction in [0, 1],
        not a 0-100 percentage.
    """
    channel_one = conv_output[:, 1, :, :]
    # 1.0 where the pixel is above the threshold, 0.0 elsewhere.
    hits = torch.where(channel_one > activation_threshold, 1.0, 0.0)
    hit_count = hits.sum(dim=(-1, -2))
    pixels_per_map = conv_output.shape[-1] * conv_output.shape[-2]
    return hit_count / pixels_per_map

# Random stand-in input: 2 samples, 4 channels, 400x400 pixels, values uniform in [0, 1).
batch = torch.rand((2, 4, 400, 400))
# Forward pass through the current `model`.
outputs = model(batch)
# Per-sample share of channel-1 activations above 1.5.
percent_outputs = percent_calc(conv_output=outputs, activation_threshold=1.5)


In [24]:
percent_outputs

tensor([76., 81.])

In [15]:
# Shape check: reducing channel 1 over both spatial axes leaves one value per sample.
ones = torch.ones((2, 2, 400, 400))
# NOTE(review): `sum` shadows the Python builtin for the rest of the session;
# the name is kept only because the following cell inspects `sum.shape`.
sum = ones[:, 1, :, :].sum(dim=(-1, -2))

In [16]:
sum.shape

torch.Size([2])

In [7]:
model.features

Sequential(
  (0): Conv2d(4, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (2): ReLU(inplace=True)
  (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (4): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (5): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (6): ReLU(inplace=True)
  (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (8): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (9): ReLU(inplace=True)
  (10): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (11): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (12): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (13): ReLU(inplace=True)
  (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(

In [8]:
# Append a two-channel conv head to the feature extractor.
# NOTE: `+=` extends model.features in place, so re-running this cell appends
# the head a second time; rebuild `model` before re-running.
model.features += nn.Sequential(
            # assumes model.features currently ends with 512 channels — TODO confirm
            nn.Conv2d(512, 2, kernel_size=3, padding=1),
            nn.BatchNorm2d(2),
            nn.ReLU(inplace=True), 
            # The layers above emit 2 channels, so this conv must accept 2
            # (it declared 512, which would fail on the first forward pass).
            nn.Conv2d(2, 2, kernel_size=3, padding=1),
            nn.BatchNorm2d(2),
            nn.ReLU(inplace=True) 
    )

In [9]:
model.features

Sequential(
  (0): Conv2d(4, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (2): ReLU(inplace=True)
  (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (4): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (5): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (6): ReLU(inplace=True)
  (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (8): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (9): ReLU(inplace=True)
  (10): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (11): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (12): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (13): ReLU(inplace=True)
  (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(