# Solar

In [None]:
#|code-fold: true
#|output: false
from IPython.display import display, HTML
display(HTML("<style>.container { width:80% !important; }</style>"))
import matplotlib.ticker as ticker

import numpy as np
import random
from pathlib import Path
from tqdm import tqdm
import torch
torch.manual_seed(10) 
import torch.optim as optim
import torch.nn as nn
import platform
from PIL import Image
from PIL.ImageStat import Stat
import datetime
import matplotlib.pyplot as plt
from matplotlib import cm
from torch.utils.data import DataLoader, Dataset, random_split, WeightedRandomSampler, SubsetRandomSampler
from torchvision.transforms import Compose, ToTensor, Normalize, RandomApply, ColorJitter, ToPILImage
import ipyplot
from torchmetrics import JaccardIndex
from torchmetrics.functional import jaccard_index
from processing import BigImage, prep_data
import datetime
import json
from constants import ROOT, RUNS_FOLDER


from step_by_step import StepByStep, InverseNormalize, load_tensor, get_means_and_stdevs 
from categorize import check_for_missing_files, LABELS, show_image, overlay_two_images
from models import Segnet
plt.style.use('fivethirtyeight')

from evaluate import display_images, evaluate_unlabeled

def get_current_datetime(fmt='%Y_%m_%d_%H_%M'):
    """Return the current local time formatted with ``strftime``.

    The default format (``YYYY_MM_DD_HH_MM``) contains only digits and
    underscores, so the result can be embedded directly in file and folder
    names. It has minute resolution: two calls within the same minute return
    the same string. Pass a finer format (e.g. ``'%Y_%m_%d_%H_%M_%S'``) when
    that matters.

    Args:
        fmt: ``strftime`` format string. Defaults to ``'%Y_%m_%d_%H_%M'``,
            which is the format previously hard-coded here.

    Returns:
        The formatted timestamp string.
    """
    return datetime.datetime.now().strftime(fmt)

%load_ext autoreload
%autoreload 2

# Data processing

In [None]:
# Segmentation classes: label index -> class name.
title_mapping = dict(enumerate(
    ['background', 'commonrack', 'commonpanel', 'denserack', 'densepanel']
))
n_classes = len(title_mapping)

# ColorJitter draws new brightness/contrast factors on each call; RandomApply
# applies it with probability 0.3. Wrapping in an nn.ModuleList is the form
# torchvision documents for RandomApply.
jitter = ColorJitter(brightness=(0.2, 1.0), contrast=(0.3, 1.0))
applier = RandomApply(torch.nn.ModuleList([jitter]), p=0.3)

# prep_data builds the loaders and every other object the later cells need.
(
    train_loader,
    val_loader,
    weights,
    n_channels,
    normalizer,
    unlabeled_tensor_x,
    unlabeled_tensor_y,
    val_composer,
    idx_map,
) = prep_data(n_classes, applier)

# Training

In [None]:
# Seed *before* building the model. Weight initialisation then no longer
# depends on how much of the global torch RNG earlier cells (or earlier runs
# of this cell) consumed, so re-running the cell gives the same starting
# weights. This assumes Segnet initialises its parameters from torch's global
# RNG, as standard nn layers do.
torch.manual_seed(17)
segnet = Segnet(n_channels=n_channels, n_classes=n_classes)
optimizer = optim.Adam(segnet.parameters(), lr=3e-4)

# Class-weighted cross-entropy; `weights` is returned by prep_data.
sbs = StepByStep(segnet, optimizer, nn.CrossEntropyLoss(weight=weights))
sbs.set_loaders(train_loader, val_loader)

In [None]:
# Per-run output folder: RUNS_FOLDER/<timestamp>_<run_name>. Checkpoints are
# saved here, and the folder is also handed to evaluate_unlabeled.
run_name = 'five_class_segnet_all_pixels_cpu'
run_folder = Path(RUNS_FOLDER).joinpath(f'{get_current_datetime()}_{run_name}')
run_folder.mkdir(parents=True, exist_ok=True)

In [None]:
# Train using the loaders registered on sbs in the setup cell.
n_epochs = 50
sbs.train(n_epochs)

# Evaluation

In [None]:
# Score the model with the Jaccard index (IoU). get_metric presumably stores
# the score on sbs.metric, which is read back here and embedded in the
# checkpoint file name.
sbs.get_metric(jaccard_index)
checkpoint_name = f'{get_current_datetime()}_checkpoint_jaccard_index_{sbs.metric:.3f}.tar'
sbs.save_checkpoint(run_folder / checkpoint_name)

In [None]:
# Start a new pass over the validation loader. The next cell pulls one batch
# from this iterator each time it runs; re-run this cell to begin a new pass.
val_loader_iter = iter(val_loader)

In [None]:
# Show model predictions for the next validation batch; re-run to step
# through batches.
try:
    x_val, y_val = next(val_loader_iter)
except StopIteration:
    # The iterator is exhausted after the last batch. Start a new pass
    # instead of failing when the cell is re-run past the end.
    val_loader_iter = iter(val_loader)
    x_val, y_val = next(val_loader_iter)

# Per-pixel class prediction: argmax over dim 1 of the model output.
y_pred = sbs.predict(x_val, to_numpy=False).argmax(1)
display_images(x_val, y_val, y_pred, normalizer)

In [None]:
# Run the trained model over the unlabeled data. run_folder is passed so the
# helper can presumably write its outputs alongside this run's checkpoints.
evaluate_unlabeled(
    unlabeled_tensor_x,
    unlabeled_tensor_y,
    val_composer,
    normalizer,
    sbs,
    run_folder,
    idx_map,
)