In [10]:
from filmModels import *
from dataPreprocess import *
from filmPostProcess import *
import os

In [11]:
# Load the saved observation array and swap its axes (the recorded shape
# in the cell below is (561, 10)).
bh_interpolated = np.transpose(np.load('components_analysis/bh_observations.npy'))

In [12]:
# Read the table from CSV; `header=[0]` takes the column labels from the first row.
# NOTE(review): `df` is not referenced by any other cell visible here.
df = pd.read_csv('components_analysis/df_unique.csv', header=[0])

In [13]:
# Sanity check on the loaded array; the recorded output is (561, 10).
np.shape(bh_interpolated)

(561, 10)

In [14]:
# Load the previously saved decomposition results from disk.
# NOTE(review): the `_n3` suffix presumably means three components - confirm.
sources = np.load('components_analysis/sources_n3.npy')
# `mixes` is reduced with an argmax over its last axis in the next cell.
mixes = np.load('components_analysis/mix_n3.npy')

In [15]:
# Label every row of `mixes` with the index of its largest entry along the last axis.
ica_class = mixes.argmax(axis=-1)

In [16]:
# One label per row is expected; the recorded output is (561,).
np.shape(ica_class)

(561,)

In [None]:
# Name of the data root; presumably the folder the image dataset is read from.
# NOTE(review): not used by any cell visible here - confirm where it is consumed.
data_root = "saga_data"

In [None]:
def _match_label_shape(predicted, labels):
    """Give the network output the same shape as `labels` before the loss is computed."""
    # A bare squeeze() also removes the batch axis when a batch holds a single
    # sample, which no longer matches the labels. Reshape whenever the element
    # counts agree and fall back to the previous squeeze() behaviour otherwise.
    if predicted.numel() == labels.numel():
        return predicted.reshape(labels.shape)
    return predicted.squeeze()


def train(trainloader, testloader, print_epochs = False, loss_fn = torch.nn.BCELoss()):
    """Train a freshly initialised ResNet-18 and checkpoint it after every epoch.

    The network stem is replaced to accept `n_channels` input channels and the
    final layer to emit `n_components` values. Besides its arguments the
    function reads the notebook globals `n_channels`, `n_components`, `device`,
    `L2_param`, `train_max_epoch` and `models_dir`.

    Parameters
    ----------
    trainloader : iterable of dict
        Batches used for optimisation. Each batch must provide the keys
        'image' and 'components_strengths'.
    testloader : iterable of dict
        Batches with the same keys, used only to compute the validation loss.
    print_epochs : bool
        When True, print the losses after every epoch and a final message.
    loss_fn : callable
        Loss applied to (prediction, labels).
        NOTE(review): the replaced head is a plain nn.Linear with no sigmoid,
        while the default BCELoss expects inputs in [0, 1] - confirm that
        callers pass a suitable loss or that the default is intended.

    Returns
    -------
    numpy.ndarray
        Array of shape (train_max_epoch, 2); column 0 is the mean training
        loss per batch, column 1 the mean validation loss per batch.

    Side effects
    ------------
    Writes `epoch-<k>.pt` (the model state dict) into `models_dir` after each
    epoch, creating the directory first if it does not exist.
    """
    model = models.resnet18()
    model.conv1 = nn.Conv2d(n_channels, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    model.fc = nn.Linear(in_features=512, out_features=n_components)
    model.to(device)

    optimizer = torch.optim.Adam(model.parameters(), weight_decay = L2_param)

    # torch.save does not create missing directories; an empty models_dir means
    # "current directory" and needs no creation.
    if models_dir:
        os.makedirs(models_dir, exist_ok=True)

    epoch_loss = np.zeros([train_max_epoch, 2])
    for epoch in range(train_max_epoch):

        # ----------- optimise on the training batches --------------
        model.train()
        running_loss_sum = 0.0
        for data in trainloader:
            image, labels = data['image'].to(device), data['components_strengths'].to(device)

            predicted = model(image)
            loss = loss_fn(_match_label_shape(predicted, labels), labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            running_loss_sum += loss.item()

        # ----------- get validation loss for current epoch --------------
        model.eval()
        validation_loss_sum = 0.0
        # No gradients are needed here; skipping graph construction saves memory.
        with torch.no_grad():
            for data in testloader:
                image, labels = data['image'].to(device), data['components_strengths'].to(device)

                predicted = model(image)
                loss = loss_fn(_match_label_shape(predicted, labels), labels)

                validation_loss_sum += loss.item()

        # ---------------- record / print statistics ------------------------
        running_loss = running_loss_sum / len(trainloader)
        validation_loss = validation_loss_sum / len(testloader)
        epoch_loss[epoch, :] = [running_loss, validation_loss]

        if print_epochs:
            print('epoch %2d: running loss: %.5f, validation loss: %.5f' %
                          (epoch + 1, running_loss, validation_loss))

        # Keep every epoch so the one with the lowest validation loss can be reloaded later.
        torch.save(model.state_dict(), os.path.join(models_dir, 'epoch-{}.pt'.format(epoch+1)))

    if print_epochs:
        print('Finished Training')

    return epoch_loss
        
def test(epoch_loss, print_model_epoch = False):
    
    # ------ select model ---------
    ind = np.argmin(epoch_loss[:, 1])
    
    
    model= models.resnet18()
    model.conv1 = nn.Conv2d(n_channels, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    model.fc = nn.Linear(in_features=512, out_features=n_components)
    
    model.load_state_dict(torch.load('{}epoch-{}.pt'.format(models_dir, ind+1)))
    
    model.to(device)
    
    if print_model_epoch:
        print("epoch {} model selected".format(ind+1))
    
    # evaluate model on test set
    model.eval()

    with torch.no_grad():
        res_strengths = []
        
        for i, data in enumerate(testloader, 0):
            image, labels = data['image'].to(device), data['components_strengths'].to(device)
            # y_test.append(label.numpy().list())
            # print(label.shape)
            # print(images.shape)

            output = model(image)
            
            res_strengths.extend(output)
            
    return res_strengths

In [None]:
# 80/20 split of the full dataset into training and held-out samples.
train_size = int(0.8 * len(full_dataset))
test_size = len(full_dataset) - train_size

# BatchNorm cannot compute batch statistics from a single sample in training
# mode, so grow the batch size until the last training batch is not of size 1.
# With fewer than two training samples no batch size can satisfy that (the
# search below would never terminate), so fail early with a clear message.
if train_size < 2:
    raise ValueError(
        "need at least 2 training samples, got {}".format(train_size))
batchsize = 20
while train_size % batchsize == 1:
    batchsize += 1
print(batchsize)

# Fixed generator seed so the split is the same on every run.
train_data, test_data = torch.utils.data.random_split(
    full_dataset,
    [train_size, test_size],
    generator=torch.Generator().manual_seed(42),
)

trainloader = DataLoader(train_data, batch_size=batchsize, shuffle=True)
# Keep evaluation order fixed: test() returns predictions in loader order, and
# they can only be matched back to `test_data` if that order is not shuffled.
testloader = DataLoader(test_data, batch_size=batchsize, shuffle=False)

In [None]:
import shap

# NOTE(review): `model`, `train_images` and `test_images` are not defined by any
# cell visible here (`model` is only a local inside train()/test()); this cell
# relies on kernel state created elsewhere - confirm before Restart & Run All.
explainer_gen = shap.DeepExplainer(model, train_images)
shap_values = explainer_gen.shap_values(test_images)

# np.abs handles both a single array and a list of per-output arrays; the
# built-in abs() raises TypeError when shap_values is a list.
shap_pixels = np.mean(np.mean(np.abs(shap_values), -1), -1)  # mean over the last two axes
shap_samples = np.mean(shap_pixels, 1)
# Original comment: "averaged over classes".
# NOTE(review): this reduces the LAST remaining axis; whether that is the class
# axis depends on the layout shap returns (list-of-outputs vs. outputs-last
# array) - verify the axis against the installed shap version.
shap_channel = np.mean(shap_samples, -1)