In [1]:
import cv2
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from torchvision import transforms, datasets
from torch.utils.data import Subset, Dataset
from PIL import Image
import math
import pickle
import time
import io
import logging
from datetime import datetime
from google.cloud import storage
import wandb

### **DattNet section:**

**DattNet Class**

In [2]:
class DattNet(nn.Module):
    """DattNet: reconstructs an image from a vector of detector measurements.

    Pipeline in ``forward``: two fully-connected layers (m_patterns -> 4096 ->
    16384), a reshape to a single-channel 128x128 map, a dilated conv +
    max-pool, an encoder of six BlueBlocks, a DenseBlock bottleneck, a decoder
    of six attention-gated (Semi)RedAttBlocks that each receive one encoder
    output as a skip connection, two conv heads, and a BatchNorm + Linear(128, 128)
    stage.

    Args:
        pic_width: stored as ``self.ims``; only the unused ``first_block`` reads
            it. ``forward`` hard-codes a 128x128 map regardless of this value.
        m_patterns: number of input features per sample (in_features of
            ``fc_layer_1``).
        input_shape: stored as-is; ``input_shape[2]`` is kept as
            ``self.batch_size`` (NOTE(review): which dimension index 2 refers to
            cannot be told from here - confirm with the caller).

    NOTE(review): ``first_block``, ``fork_block``, ``res_block`` and
    ``final_block`` are not called by ``forward`` and use attributes that
    ``__init__`` never creates; they look like leftovers of an earlier design.
    """
    def __init__(self, pic_width, m_patterns, input_shape):
        super(DattNet, self).__init__()
        self.m_patterns = m_patterns
        self.ims = pic_width  # must be even number
        self.input_shape = input_shape
        self.batch_size = input_shape[2]
        # self.filter_depth = filter_depth
        # self.kernel_size = kernel_size
        # self.classes = classes
        self.c_i7 = 768  # NOTE(review): not read anywhere in this class

        # Datt Net
        # Submodules are created (and randomly initialised) in this order;
        # reordering the assignments below changes the initial weights drawn
        # from the RNG and the order of the state_dict keys.
        # 16384 = 128 * 128: forward reshapes this output to (N, 1, 128, 128).
        self.fc_layer_1 = nn.Linear(self.m_patterns, 4096)
        self.fc_layer_2 = nn.Linear(4096, 16384)
        # 'same' padding keeps H x W; MaxPool2d(2) then halves it (128 -> 64).
        self.diconv_layer_3 = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=16, kernel_size=5, padding='same', dilation=2),
            nn.MaxPool2d((2, 2)))
        # Encoder: each BlueBlock ends with AvgPool2d(2), so the spatial size goes
        # 64 -> 32 -> 16 -> 8 -> 4 -> 2 -> 1 while channels grow 16 -> 36.
        self.blue_block_4 = BlueBlock(16, 26)
        self.blue_block_5 = BlueBlock(26, 31)
        self.blue_block_6 = BlueBlock(31, 33)
        self.blue_block_7 = BlueBlock(33, 34)
        self.blue_block_8 = BlueBlock(34, 35)
        self.blue_block_9 = BlueBlock(35, 36)
        self.dense_block_10 = DenseBlock(36, 36)
        # Decoder blocks take (F_g, F_l, F_int, c_in, c_out): F_l and c_in equal the
        # channel count of the encoder output used as skip in forward, and each
        # block upsamples x2 (nn.Upsample(scale_factor=2) in its layers).
        self.semi_red_block_11 = SemiRedAttBlock(36, 36, 36, 36, 36)
        self.semi_red_block_12 = SemiRedAttBlock(36, 35, 36, 35, 36)
        self.red_block_13 = RedAttBlock(36, 34, 36, 34, 36)
        self.red_block_14 = RedAttBlock(36, 33, 36, 33, 36)
        self.red_block_15 = RedAttBlock(36, 31, 36, 31, 36)
        self.red_block_16 = RedAttBlock(36, 26, 36, 26, 36)
        # Final x2 upsample (64 -> 128) between two 3x3 convs.
        self.layer_17 = nn.Sequential(nn.Conv2d(in_channels=36, out_channels=36, kernel_size=3, padding='same'),
                                      nn.ReLU(inplace=True),
                                      nn.Upsample(scale_factor=2),
                                      nn.Conv2d(in_channels=36, out_channels=36, kernel_size=3, padding='same'),
                                      nn.ReLU(inplace=True))
        # 1x1 conv down to a single output channel.
        self.layer_18 = nn.Sequential(nn.Conv2d(in_channels=36, out_channels=1, kernel_size=1, padding='same'),
                                      nn.ReLU(inplace=True))
        # nn.Linear acts on the last dimension, i.e. it mixes pixels along each
        # image row (width 128).
        self.nlt_19 = nn.Sequential(nn.BatchNorm2d(1),
                                    nn.Linear(128, 128))

    def forward(self, x):
        """Map detector vectors to reconstructed images.

        Args:
            x: tensor whose last dimension is ``m_patterns``
                (presumably (N, m_patterns) - confirm with the caller).

        Returns:
            Tensor of shape (N, 1, 128, 128).
        """
        # NOTE(review): no activation between the two fc layers, so together
        # they compose to a single linear map.
        x = self.fc_layer_1(x)
        x = self.fc_layer_2(x)
        x = x.reshape(-1, 1, 128, 128)
        x = self.diconv_layer_3(x)      # (N, 16, 64, 64)
        x4 = self.blue_block_4(x)       # (N, 26, 32, 32)
        x5 = self.blue_block_5(x4)      # (N, 31, 16, 16)
        x6 = self.blue_block_6(x5)      # (N, 33, 8, 8)
        x7 = self.blue_block_7(x6)      # (N, 34, 4, 4)
        x8 = self.blue_block_8(x7)      # (N, 35, 2, 2)
        x9 = self.blue_block_9(x8)      # (N, 36, 1, 1)
        x = self.dense_block_10(x9)
        # Each decoder block uses the current features as the gate for the
        # matching encoder output, then upsamples x2: 1 -> 2 -> 4 -> ... -> 64.
        x = self.semi_red_block_11(x, x9)
        x = self.semi_red_block_12(x, x8)
        x = self.red_block_13(x, x7)
        x = self.red_block_14(x, x6)
        x = self.red_block_15(x, x5)
        x = self.red_block_16(x, x4)
        x = self.layer_17(x)            # (N, 36, 128, 128)
        x = self.layer_18(x)            # (N, 1, 128, 128)
        x = self.nlt_19(x)
        return x

    def first_block(self, x):
        """NOTE(review): unused by ``forward``.

        Reads ``self.dropout``, ``self.int_fc2``, ``self.conv_block1`` and
        ``self.conv_block2``, none of which ``__init__`` creates, so calling it
        raises AttributeError. It also builds a fresh ``self.int_fc1`` (new
        random weights) on every call.
        """
        batch_size, num_feats = x.shape
        self.int_fc1 = nn.Linear(num_feats, 32 * 32)
        x = F.relu(self.int_fc1(x))
        x = self.dropout(x)

        x = F.relu(self.int_fc2(x))
        x = self.dropout(x)

        x = F.relu(self.conv_block1(x.view(-1, 1, 2 * self.ims, 2 * self.ims)))
        x = self.dropout(x)

        x = F.relu(self.conv_block2(x))
        x = self.dropout(x)
        return x

    def fork_block(self, x):
        """NOTE(review): unused by ``forward``.

        Four parallel paths (input, max-pooled by 2/4/8), each through
        ``res_block`` and upsampled back, concatenated on dim 1. Relies on
        ``self.maxpool2``, ``self.maxpool4``, ``self.maxpool8`` and
        ``self.upsample``, which ``__init__`` never creates.
        """
        # path 1
        x1 = self.res_block(x)

        # path 2
        x2 = self.maxpool2(x)
        x2 = self.res_block(x2)
        x2 = self.upsample(x2)

        # path 3
        x3 = self.maxpool4(x)
        x3 = self.res_block(x3)
        x3 = self.upsample(x3)
        x3 = self.upsample(x3)

        # path 4
        x4 = self.maxpool8(x)
        x4 = self.res_block(x4)
        x4 = self.upsample(x4)
        x4 = self.upsample(x4)
        x4 = self.upsample(x4)

        concat_x = torch.cat((x1, x2, x3, x4), 1)
        return concat_x

    def res_block(self, x):
        """Apply one residual unit 4 times: x <- relu(x + relu(conv(relu(conv(x))))).

        The same ``self.conv_res`` weights are used in every step.
        NOTE(review): unused by ``forward``; ``self.conv_res`` is never created
        in ``__init__``.
        """
        for _ in range(4):
            y = F.relu(self.conv_res(x))
            f_x = F.relu(self.conv_res(y))
            x = F.relu(x + f_x)
        return x

    def final_block(self, x):
        """NOTE(review): unused by ``forward``.

        Relies on ``self.maxpool2``, ``self.conv_block3``-``5``, ``self.dropout``
        and ``self.last_layer``, none of which ``__init__`` creates.
        """
        x = self.maxpool2(x)

        x = F.relu(self.conv_block3(x))
        x = self.dropout(x)

        x = F.relu(self.conv_block4(x))
        x = self.dropout(x)

        x = F.relu(self.conv_block5(x))
        x = self.dropout(x)

        x = F.relu(self.last_layer(x))
        return x

**Block Classes**

In [3]:
class DenseBlock(nn.Module):
    """Densely connected residual stack that keeps channel count and H x W.

    One shared ``layer`` (BatchNorm -> ReLU -> dilated 5x5 conv, 'same' padding)
    is applied three times. Each application sees the sum of the input and all
    previous outputs, and the block returns the sum of the input and all three
    outputs.

    NOTE(review): ``c_out`` is stored but unused; the conv maps ``c_in`` to
    ``c_in`` channels, so the output always has ``c_in`` channels (every
    visible caller passes ``c_in == c_out``). ``transition_layers`` is built,
    so its parameters are registered, but ``forward`` never uses it.
    """

    def __init__(self, c_in, c_out):
        super(DenseBlock, self).__init__()
        self.c_in, self.c_out = c_in, c_out
        # Creation order is kept: it fixes the random-init sequence and the
        # state_dict key order.
        self.layer = nn.Sequential(
            nn.BatchNorm2d(c_in),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels=c_in, out_channels=c_in, kernel_size=5, padding='same', dilation=2),
        )
        self.transition_layers = nn.Sequential(
            nn.BatchNorm2d(c_in),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels=c_in, out_channels=c_in, kernel_size=1, padding='same', dilation=2),
            nn.Dropout(0.05),
            nn.AvgPool2d(2),
        )

    def forward(self, x1):
        # running == x1, then x1 + x2, then x1 + x2 + x3, then x1 + x2 + x3 + x4
        running = x1
        for _ in range(3):
            running = running + self.layer(running)
        return running

In [4]:
class BlueBlock(nn.Module):
    """Encoder stage: a DenseBlock, then a transition that changes channels
    from ``c_in`` to ``c_out`` and halves H and W (AvgPool2d(2)).
    """

    def __init__(self, c_in, c_out):
        super(BlueBlock, self).__init__()
        self.c_in = c_in
        self.c_out = c_out
        self.dense_block = DenseBlock(c_in, c_in)
        # 1x1 conv, so dilation=2 has no effect on its receptive field.
        transition = [
            nn.BatchNorm2d(c_in),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels=c_in, out_channels=c_out, kernel_size=1, dilation=2),
            nn.Dropout(0.05),
            nn.AvgPool2d(2),
        ]
        self.transition_layers = nn.Sequential(*transition)

    def forward(self, x):
        return self.transition_layers(self.dense_block(x))

In [5]:
class Attention_block(nn.Module):
    """Additive attention gate.

    ``forward(g, x)`` projects the gating signal ``g`` (F_g channels) and the
    features ``x`` (F_l channels) to F_int channels with 1x1 conv + BatchNorm,
    adds them, applies ReLU, reduces to one channel (1x1 conv + BatchNorm +
    sigmoid) and returns ``x`` scaled by that per-pixel coefficient. ``g`` and
    ``x`` must have the same spatial size for the sum.
    """

    def __init__(self, F_g, F_l, F_int):
        super(Attention_block, self).__init__()
        self.params = dict(F_g=F_g, F_l=F_l, F_int=F_int)

        def conv1x1_bn(c_from, c_to):
            return [nn.Conv2d(c_from, c_to, kernel_size=1, stride=1, padding=0, bias=True),
                    nn.BatchNorm2d(c_to)]

        self.W_g = nn.Sequential(*conv1x1_bn(F_g, F_int))
        self.W_x = nn.Sequential(*conv1x1_bn(F_l, F_int))
        self.psi = nn.Sequential(*conv1x1_bn(F_int, 1), nn.Sigmoid())
        self.relu = nn.ReLU(inplace=True)

    def forward(self, g, x):
        combined = self.relu(self.W_g(g) + self.W_x(x))
        return x * self.psi(combined)

In [6]:
class RedAttBlock(nn.Module):
    """Decoder stage with 3x3 convolutions.

    ``forward(g, x)``: gates the skip tensor ``x`` (F_l channels) with ``g``
    (F_g channels) via an Attention_block, then applies 3x3 conv -> ReLU -> 2x
    upsample -> 3x3 conv (c_in -> c_out) -> ReLU -> DenseBlock. The gated
    tensor keeps ``x``'s F_l channels, so ``c_in`` must equal ``F_l``.
    """

    def __init__(self, F_g, F_l, F_int, c_in, c_out):
        super(RedAttBlock, self).__init__()
        self.params = dict(F_g=F_g, F_l=F_l, F_int=F_int, c_in=c_in, c_out=c_out)
        self.att_gate = Attention_block(F_g, F_l, F_int)
        stages = [nn.Conv2d(c_in, c_in, kernel_size=3, padding='same'),
                  nn.ReLU(inplace=True),
                  nn.Upsample(scale_factor=2),
                  nn.Conv2d(c_in, c_out, kernel_size=3, padding='same'),
                  nn.ReLU(inplace=True)]
        self.layers = nn.Sequential(*stages)
        # attribute name (sic) is part of the state_dict keys - keep it
        self.dence_block = DenseBlock(c_out, c_out)

    def forward(self, g, x):
        gated = self.att_gate(g, x)
        return self.dence_block(self.layers(gated))

In [7]:
class SemiRedAttBlock(nn.Module):
    """Decoder stage with 1x1 convolutions (used at the lowest resolutions).

    ``forward(g, x)``: gates the skip tensor ``x`` (F_l channels) with ``g``
    (F_g channels) via an Attention_block, then applies 1x1 conv -> ReLU -> 2x
    upsample -> 1x1 conv (c_in -> c_out) -> ReLU -> DenseBlock. The gated
    tensor keeps ``x``'s F_l channels, so ``c_in`` must equal ``F_l``.
    """

    def __init__(self, F_g, F_l, F_int, c_in, c_out):
        super(SemiRedAttBlock, self).__init__()
        self.params = dict(F_g=F_g, F_l=F_l, F_int=F_int, c_in=c_in, c_out=c_out)
        self.att_gate = Attention_block(F_g, F_l, F_int)
        stages = [nn.Conv2d(c_in, c_in, kernel_size=1, padding='same'),
                  nn.ReLU(inplace=True),
                  nn.Upsample(scale_factor=2),
                  nn.Conv2d(c_in, c_out, kernel_size=1, padding='same'),
                  nn.ReLU(inplace=True)]
        self.layers = nn.Sequential(*stages)
        # attribute name (sic) is part of the state_dict keys - keep it
        self.dence_block = DenseBlock(c_out, c_out)

    def forward(self, g, x):
        gated = self.att_gate(g, x)
        return self.dence_block(self.layers(gated))

### **DataFunctions section:**

In [8]:
def get_data(generate, log_path, batch_size, pic_width, prc_patterns, m_patterns, num_train_coco, num_train_mnist,
             num_test_coco, num_test_mnist, n_gray_levels, data_sets):
    """Return ``(train_loader, test_loader, data_time_message, patterns)``.

    Detector data is cached as pickles in the "our_train_test_data" bucket under
    names built from the sample counts, ``prc_patterns`` and ``n_gray_levels``.
    If both cached files exist and ``generate`` is False, they are loaded and
    ``patterns`` is returned as an empty list, because patterns are not cached.
    Otherwise the raw datasets are loaded, converted to detector data with new
    random patterns, uploaded, and wrapped in DataLoaders.

    NOTE(review): the cache name ignores ``pic_width``, ``m_patterns`` and
    ``data_sets``, so a cached file can be reused for a different
    configuration. ``log_path`` is unused.
    """
    name_suffix = f"_mp{prc_patterns}_ngl{n_gray_levels}.pickle"
    train_set_file_name = f"Loaders/train_loader_trNum_C{num_train_coco}_M{num_train_mnist}" + name_suffix
    test_set_file_name = f"Loaders/test_loader_tsNum_C{num_test_coco}_M{num_test_mnist}" + name_suffix

    bucket = storage.Client().get_bucket("our_train_test_data")
    cached = bucket.blob(train_set_file_name).exists() and bucket.blob(test_set_file_name).exists()

    start = time.time()
    if cached and not generate:
        print("loading existing data loaders")   # print message for debug
        train_loader, test_loader = load_generated_data(train_set_file_name, test_set_file_name, batch_size, bucket)
        elapsed = round(time.time() - start)
        return train_loader, test_loader, f"Loading the data took : {elapsed} sec", []

    train_data, test_data = load_data(data_sets, batch_size, pic_width, num_train_coco,
                                      num_train_mnist, num_test_coco, num_test_mnist)
    train_detector_data, test_detector_data, patterns = generate_data(train_data, test_data, pic_width, m_patterns,
                                                                      batch_size, n_gray_levels, patterns='new')
    print("finished generating. starting to save the loaders")   # print message for debug
    save_generated_data(train_set_file_name, train_detector_data, test_set_file_name,
                        test_detector_data, bucket)

    train_loader = torch.utils.data.DataLoader(train_detector_data, batch_size=batch_size)
    test_loader = torch.utils.data.DataLoader(test_detector_data, batch_size=batch_size)
    print("starting to train :)")   # print message for debug

    elapsed = round(time.time() - start)
    if elapsed > 3600:    # if took more than an hour
        data_time_message = f"Generating the data took : {elapsed / 3600} hours"
    else:
        data_time_message = f"Generating the data took : {elapsed / 60} mins"
    return train_loader, test_loader, data_time_message, patterns


# TODO: when loading, the training fails for sweeps (for regular runs it works)
def load_generated_data(train_set_file_name, test_set_file_name, batch_size, bucket):
    """Download the cached train/test detector datasets from ``bucket`` and wrap each in a DataLoader.

    Both blobs are downloaded first, then unpickled, then wrapped (train before
    test at every step). Returns ``(train_loader, test_loader)``.

    SECURITY NOTE(review): ``pickle.load`` can execute arbitrary code from the
    blob contents; only use this with a trusted bucket.
    """
    def fetch(blob_name):
        return io.BytesIO(bucket.blob(blob_name).download_as_bytes())

    raw_train, raw_test = fetch(train_set_file_name), fetch(test_set_file_name)
    loaded = [pickle.load(buffer) for buffer in (raw_train, raw_test)]
    train_loader, test_loader = (torch.utils.data.DataLoader(ds, batch_size=batch_size) for ds in loaded)
    return train_loader, test_loader

**Generate new data loaders**

In [9]:
def load_data(data_sets, batch_size, pic_width, num_train_coco, num_train_mnist, num_test_coco, num_test_mnist,
              num_train_div2k=None, num_test_div2k=None):
    """Load the raw image datasets named in ``data_sets``.

    Returns ``(train_data, test_data)``. Each is a 3-item list ordered
    ``[coco, div2k, mnist]``; a dataset not named in ``data_sets`` is an empty list.

    ``num_train_div2k`` / ``num_test_div2k`` set the DIV2K sample counts. When
    left as None they fall back to the COCO counts, which is the value DIV2K
    always received before these parameters existed.
    """
    if num_train_div2k is None:
        num_train_div2k = num_train_coco
    if num_test_div2k is None:
        num_test_div2k = num_test_coco

    train_data_coco, train_data_div2k, train_data_mnist = [], [], []
    test_data_coco, test_data_div2k, test_data_mnist = [], [], []
    if 'coco' in data_sets:
        train_data_coco, test_data_coco = load_data_coco(batch_size, pic_width, num_train_coco, num_test_coco)
    if 'div2k' in data_sets:
        train_data_div2k, test_data_div2k = load_data_div2k(batch_size, pic_width, num_train_div2k, num_test_div2k)
    if 'mnist' in data_sets:
        train_data_mnist, test_data_mnist = load_data_mnist(batch_size, pic_width, num_train_mnist, num_test_mnist)
    train_data = [train_data_coco, train_data_div2k, train_data_mnist]
    test_data = [test_data_coco, test_data_div2k, test_data_mnist]
    return train_data, test_data


def generate_data(train_data, test_data, pic_width, m_patterns, batch_size, n_gray_levels, patterns):
    """Turn raw train/test image datasets into simulated detector (GI) data.

    Args:
        train_data, test_data: parallel lists of datasets, one entry per source
            dataset (``get_data`` passes the output of ``load_data``).
        patterns: the string ``'new'`` to draw ``m_patterns`` random
            ``pic_width`` x ``pic_width`` light patterns; anything else is used
            as the pattern tensor itself.

    Returns:
        ``(train_detector_data, test_detector_data, patterns)``: two lists of
        ``[detector_output, flattened_low_gray_image]`` pairs and the patterns used.
    """
    # Only a str can mean "new". Comparing a tensor/array with a str is not
    # guaranteed to give a plain bool (newer NumPy versions may return an
    # element-wise array, which makes `if` raise), so check the type first.
    if isinstance(patterns, str) and patterns == 'new':
        patterns = define_m_random_patterns(pic_width, m_patterns)
    train_detector_data, test_detector_data = create_detector_data_for_multidatasets(
        train_data, test_data, patterns, pic_width, n_gray_levels, batch_size)
    return train_detector_data, test_detector_data, patterns
        

# TODO: fix memory issues in 2nd/3rd runs of a sweep when trying to save the datasets
def save_generated_data(train_set_file_name, train_data, test_set_file_name, test_data, bucket):
    """Pickle ``train_data`` and ``test_data`` and upload each to its blob in ``bucket``.

    Each object is serialized and uploaded inside its own helper call. Only one
    serialized payload is therefore alive at a time, so peak memory is roughly
    the larger payload rather than the sum of both.

    Uploads use content type ``application/octet-stream``; train is uploaded
    before test.
    """
    def _upload(blob_name, obj):
        # Serialize, upload, and let `payload` go out of scope before the next object.
        payload = pickle.dumps(obj)
        bucket.blob(blob_name).upload_from_string(payload, content_type='application/octet-stream')

    _upload(train_set_file_name, train_data)
    _upload(test_set_file_name, test_data)

In [10]:
# custom dataset class to load coco images from the bucket
class CustomImageFolder(Dataset):
    """Dataset over every image blob under ``root/`` in the "our_train_test_data" bucket.

    Items are ``(image, 0)``: the blob is downloaded, decoded with PIL,
    converted to RGB and passed through ``transform`` when one is given. The
    label is always 0. The file list is fetched once in ``__init__``; image
    bytes are downloaded on every ``__getitem__``.
    """

    def __init__(self, root, transform=None):
        self.root = root
        self.transform = transform
        client = storage.Client()
        self.bucket_name = "our_train_test_data"
        self.bucket = client.get_bucket(self.bucket_name)

        self.files = self.get_files_from_bucket(self.root)

    def get_files_from_bucket(self, directory):
        """Return the names of all blobs under ``directory/`` whose name does not end with '/'."""
        names = []
        for blob in self.bucket.list_blobs(prefix=f"{directory}/"):
            if blob.name.endswith('/'):
                continue  # folder marker, not an image
            names.append(blob.name)
        return names

    def __getitem__(self, index):
        raw_bytes = self.bucket.blob(self.files[index]).download_as_bytes()
        img = Image.open(io.BytesIO(raw_bytes)).convert('RGB')
        return (img if self.transform is None else self.transform(img)), 0

    def __len__(self):
        return len(self.files)
    
    
def load_data_coco(batch_size, pic_width, num_train_samples, num_test_samples):
    """Load COCO images from the bucket folders and return ``(train_data, test_data)`` Subsets.

    Every image is resized to ``pic_width`` x ``pic_width``, converted to a
    1-channel float tensor and normalized from [0, 1] to [-1, 1].

    The requested sample counts are clamped to the number of images found, so
    each Subset's ``len()`` matches what can actually be read. Without the
    clamp, an index past the end would only fail once that item is fetched.
    ``batch_size`` is unused and kept for interface compatibility.
    """
    train_images_path = 'data/coco-2017/all_data/train'
    test_images_path = 'data/coco-2017/all_data/test'

    transform = transforms.Compose([transforms.Resize((pic_width, pic_width)),
                                    transforms.ToTensor(),
                                    # Grayscale runs on the tensor here (after ToTensor)
                                    transforms.Grayscale(num_output_channels=1),
                                    transforms.Normalize((0.5,), (0.5,)),
                                    lambda x: x.float()])

    # load the datasets from the bucket
    train_data = CustomImageFolder(train_images_path, transform=transform)
    test_data = CustomImageFolder(test_images_path, transform=transform)

    # slice the data, never past the number of available images
    train_data = Subset(train_data, np.arange(min(num_train_samples, len(train_data))))
    test_data = Subset(test_data, np.arange(min(num_test_samples, len(test_data))))

    return train_data, test_data

In [11]:
def load_data_mnist(batch_size, pic_width, num_train_mnist, num_test_mnist):
    """Load MNIST from the local ``data`` folder (no download) as binarized images.

    Each image is resized to ``pic_width`` x ``pic_width``, normalized with
    mean 0.1307 / std 0.3081, and thresholded at 0. A pixel therefore becomes
    1.0 when its [0, 1] value is above 0.1307, else 0.0.

    A dataset is sliced to its first ``num_*`` samples only when fewer than the
    full dataset are requested. ``batch_size`` is unused and kept for interface
    compatibility. Reads only local files; no cloud client is needed.
    """
    transform = transforms.Compose([transforms.Resize((pic_width, pic_width)),
                                    transforms.ToTensor(),
                                    transforms.Normalize((0.1307,), (0.3081,)),
                                    lambda x: x > 0,
                                    lambda x: x.float()])  # convert data to torch.FloatTensor

    train_data = datasets.MNIST(root='data', train=True, download=False, transform=transform)
    test_data = datasets.MNIST(root='data', train=False, download=False, transform=transform)

    # slice the data if fewer samples than available are requested
    if num_train_mnist < len(train_data):
        train_data = Subset(train_data, np.arange(num_train_mnist))
    if num_test_mnist < len(test_data):
        test_data = Subset(test_data, np.arange(num_test_mnist))

    return train_data, test_data

In [12]:
# this is not used - we don't have div2k images in the bucket
def load_data_div2k(batch_size, pic_width, num_train_samples, num_test_samples,
                    train_images_path=None, test_images_path=None):
    """Load DIV2K images from local folders and return ``(train_data, test_data)`` Subsets.

    Every image is resized to ``pic_width`` x ``pic_width``, converted to a
    1-channel float tensor and normalized from [0, 1] to [-1, 1]; the first
    ``num_*_samples`` images are kept.

    ``train_images_path`` / ``test_images_path`` default to the original local
    Windows folders; pass other paths to run on another machine.
    NOTE(review): ``datasets.ImageFolder`` expects one sub-folder per class under
    the root - confirm the DIV2K folders are laid out that way.
    ``batch_size`` is unused and kept for interface compatibility.
    """
    if train_images_path is None:
        train_images_path = r'C:\Users\user\Desktop\Projects\DattNet\data\div2k\DIV2K_train_HR'
    if test_images_path is None:
        test_images_path = r'C:\Users\user\Desktop\Projects\DattNet\data\div2k\DIV2K_valid_HR'

    transform = transforms.Compose([transforms.Resize((pic_width, pic_width)),
                                    transforms.ToTensor(),
                                    transforms.Grayscale(num_output_channels=1),
                                    transforms.Normalize((0.5,), (0.5,)),
                                    lambda x: x.float()])

    train_data = datasets.ImageFolder(train_images_path, transform=transform)
    test_data = datasets.ImageFolder(test_images_path, transform=transform)

    # slice the data
    train_data = Subset(train_data, np.arange(num_train_samples))
    test_data = Subset(test_data, np.arange(num_test_samples))

    return train_data, test_data

In [13]:
def define_m_random_patterns(pic_width, m, generator=None):
    """Return ``m`` random light patterns as a tensor of shape (m, pic_width, pic_width).

    Values are drawn uniformly from [0, 1) by ``torch.rand``. Pass a seeded
    ``torch.Generator`` to make the patterns reproducible; with
    ``generator=None`` the global torch RNG is used.
    """
    return torch.rand(m, pic_width, pic_width, generator=generator)


def create_detector_data_for_multidatasets(train_data, test_data, patterns, pic_width, n_gray_levels, batch_size):
    """Convert parallel lists of train/test datasets into two merged lists of detector pairs.

    ``train_data[k]`` and ``test_data[k]`` are processed together, train first,
    and their ``[detector_output, low_gray_image]`` pairs are appended to the
    merged train/test lists. An empty entry (a dataset that was not loaded)
    contributes nothing. Pairs are truncated to the shorter list (``zip``).
    """
    merged_train, merged_test = [], []
    for train_set, test_set in zip(train_data, test_data):
        merged_train += create_detector_data(train_set, patterns, pic_width, n_gray_levels, batch_size)
        merged_test += create_detector_data(test_set, patterns, pic_width, n_gray_levels, batch_size)
    return merged_train, merged_test


# This function runs separately for each dataset (coco/mnist) and for train/test.
# Images are walked in chunks of `batch_size` only to print a timing line per
# chunk; each image is still processed individually.
def create_detector_data(data_set, patterns, pic_width, n_gray_levels, batch_size):
    """Convert every image of ``data_set`` into a ``[detector_output, flattened_low_gray_image]`` pair.

    The first element of each sample is viewed as a ``pic_width`` x
    ``pic_width`` float32 image, equalized/quantized by ``hist_equalization``
    and fed to ``sample_after_patterns``. Returns a flat list of pairs.
    """
    def to_pair(sample):
        image = sample[0].view(pic_width, pic_width).to(torch.float32)
        low_gray_image = hist_equalization(image, n_gray_levels)
        detector_output = sample_after_patterns(low_gray_image, patterns).clone().detach()
        return [detector_output, torch.flatten(low_gray_image)]

    detector_data = []
    n_items = len(data_set)
    for chunk_start in range(0, n_items, batch_size):
        tic = time.time()    # for print message after each batch
        chunk_stop = min(chunk_start + batch_size, n_items)
        # NOTE(review): Subset appears to define no __iter__, so iterating it uses
        # the __getitem__ protocol; an IndexError from the underlying dataset would
        # end the chunk silently - confirm.
        chunk = Subset(data_set, list(range(chunk_start, chunk_stop)))
        detector_data.extend([to_pair(sample) for sample in chunk])
        print(f"{time.time()-tic}:\t\tfinished a batch")   # print message for debug

    return detector_data


def sample_after_patterns(image, patterns):
    """Simulate the detector: one scalar reading per light pattern.

    Returns a 1-D tensor whose i-th entry is ``sum(image * patterns[i])``, on
    ``image``'s device and in torch's default float dtype.

    Computed as one matrix-vector product instead of a Python loop over the
    ``m`` patterns. Results match the per-pattern sums up to float rounding,
    because the summation order differs.

    Args:
        image: 2-D tensor (H, W) with the same spatial shape as each pattern
            (``create_detector_data`` passes ``view(pic_width, pic_width)``).
        patterns: tensor of shape (m, H, W).
    """
    patterns_tensor = patterns.to(image.device)
    dtype = torch.promote_types(image.dtype, patterns_tensor.dtype)
    # flatten(start_dim=1) also handles m == 0, where reshape(0, -1) would be ambiguous
    flat_patterns = patterns_tensor.flatten(start_dim=1).to(dtype)   # (m, H*W)
    flat_image = image.flatten().to(dtype)                           # (H*W,)
    detector_output = torch.mv(flat_patterns, flat_image)            # (m,)
    return detector_output.to(torch.get_default_dtype())

### **LogFunctions section:**

In [14]:
def print_run_info_to_log(batch_size, pic_width, prc_patterns, n_gray_levels, m_patterns, initial_lr, div_factor_lr,
                          num_dif_lr, n_epochs, num_train_coco, num_train_mnist, num_test_coco, num_test_mnist,
                          lr_vector, epochs_vector, run_name, TRAIN_BY, folder_path='Logs'):
    """Open a new timestamped log file in ``folder_path``, write the run summary, and return its path.

    The file is ``<folder_path>/Log_<dd_mm_YYYY__HH_MM_SS>_prc_<prc_patterns>.log``.
    ``folder_path`` is created if missing.

    Handlers already attached to the root logger are removed first.
    ``logging.basicConfig`` does nothing while the root logger has handlers, so
    without this every run in the same kernel (e.g. a sweep) would keep writing
    into the first run's file.

    The sample-count lines are logged at WARNING directly rather than through
    ``print_and_log_message``. That helper sets the root level to WARNING,
    which silently dropped the INFO separator logged after them. The level is
    set to WARNING at the end, the level ``print_and_log_message`` uses for the
    rest of the run.

    ``n_gray_levels`` is accepted but not logged.
    """
    import os

    dt_string = datetime.now().strftime("%d_%m_%Y__%H_%M_%S")
    log_name = f"Log_{dt_string}_prc_{prc_patterns}.log"
    print(f'Name of log file: {log_name}')
    log_path = folder_path + '/' + log_name
    os.makedirs(folder_path, exist_ok=True)   # FileHandler does not create missing directories

    root_logger = logging.getLogger()
    for handler in list(root_logger.handlers):   # drop the previous run's file handler
        root_logger.removeHandler(handler)
        handler.close()
    logging.basicConfig(filename=log_path, format='%(asctime)s %(message)s', filemode='w')
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)

    logger.info(f"This is a summary of the run '{run_name}':")
    logger.info(f'Batch size for this run: {batch_size}')
    logger.info(f'Size of original image: {pic_width}x{pic_width}')
    logger.info(f'Number of patterns: {m_patterns} which is {prc_patterns}% of {pic_width}^2')
    # logger.info(f'Number of gray levels in output image: {n_gray_levels}')

    if TRAIN_BY == "Vectors":
        logger.info(f"lr_vector {lr_vector}, epochs_vector {epochs_vector}")
    else:
        logger.info(f"Initial lr: {initial_lr}, Division factor: {div_factor_lr},\n\t\t"
                    f"{num_dif_lr} divisions with {n_epochs} epochs for each")

    logger.warning(f"Number of samples in train is {num_train_mnist}-mnist and {num_train_coco}-coco")
    logger.warning(f"Number of samples in test is {num_test_mnist}-mnist and {num_test_coco}-coco")

    logger.info('*************************\n\n')
    logger.setLevel(logging.WARNING)

    # To also keep the log in the bucket (not only locally), upload it, e.g.:
    # storage.Client().get_bucket("our_train_test_data").blob(log_path).upload_from_filename(log_path)

    return log_path

In [15]:
def print_and_log_message(message, log_path):
    """Log ``message`` to the root logger, which writes to ``log_path`` if it had no handlers yet.

    Strings are logged at WARNING level. Exception instances are logged at
    ERROR level with their own traceback (``exc_info=message``), so this also
    works outside an ``except`` block, where ``logger.exception`` would only
    record ``NoneType: None``. Any other object goes through
    ``logger.exception``.

    The root logger's level is set to WARNING on every call.
    NOTE(review): despite the name, nothing is printed. ``basicConfig`` only
    takes effect while the root logger has no handlers, so ``log_path`` is
    ignored once logging has been configured.
    """
    logging.basicConfig(filename=log_path, format='%(asctime)s %(message)s', filemode='w')
    logger = logging.getLogger()
    logger.setLevel(logging.WARNING)

    if isinstance(message, str):
        logger.warning(message)
    elif isinstance(message, BaseException):
        logger.error(message, exc_info=message)
    else:
        logger.exception(message)

### **Support Functions section:**

In [16]:
def make_folder(net_name, num_samp_coco, num_samp_mnist, batch_size, n_gray_levels, prc_patterns):
    """Ensure a results "folder" for this configuration exists in the bucket; return its path.

    The path is ``Results/<net_name>_NumSamp_C.._M.._bs_.._prc..``. If no blob
    exists under ``<path>/``, an empty placeholder blob named ``<path>/`` is
    uploaded. ``n_gray_levels`` is accepted but not used in the name.
    """
    folder_name = f'{net_name}_NumSamp_C{num_samp_coco}_M{num_samp_mnist}_bs_{batch_size}_prc{prc_patterns}'
    print(folder_name)
    folder_path = 'Results/' + folder_name

    bucket = storage.Client().get_bucket("our_train_test_data")

    # Match "<folder_path>/" so that e.g. "..._prc1" is not treated as existing
    # just because "..._prc10/..." does; max_results=1 stops at the first hit.
    folder_prefix = folder_path + '/'
    if not any(bucket.list_blobs(prefix=folder_prefix, max_results=1)):
        bucket.blob(folder_prefix).upload_from_string('')   # empty placeholder marks the folder
    return folder_path


def save_img_to_bucket(img_path, img_array, n_gray_levels):
    """Encode a 2-D low-gray-level image as grayscale JPEG and upload it to ``img_path`` in the bucket.

    ``img_array`` (a NumPy array) is multiplied by ``n_gray_levels`` to map it
    back towards 0-255. It is then clipped to [0, 255] before the uint8 cast,
    so out-of-range values (e.g. raw network outputs) saturate instead of
    wrapping around.
    """
    storage_client = storage.Client()
    bucket = storage_client.get_bucket("our_train_test_data")

    # hist_equalization yields values in 0..255 // n_gray_levels, so scaling by
    # n_gray_levels lands in 0..255 (e.g. at most 240 for 30 levels).
    pixels = np.clip(img_array * n_gray_levels, 0, 255).astype(np.uint8)
    pil_image = Image.fromarray(pixels).convert('L')
    byte_buffer = io.BytesIO()
    pil_image.save(byte_buffer, format='JPEG')

    bucket.blob(img_path).upload_from_string(byte_buffer.getvalue(), content_type='image/jpeg')
    
    
# this function is not used
def discretize(my_tensor, n_gray_levels):
    """Rescale 0-255 values to 0..n_gray_levels and round to the nearest level (ties to even)."""
    return (my_tensor / 255 * n_gray_levels).round()

**Histogram Equalization functions**

In [17]:
def hist_equalization(image, n_gray_levels):
    """Histogram-equalize ``image`` and quantize it by integer division with ``n_gray_levels``.

    Steps:
    1. Shift/scale the image to 0-255 and truncate it to uint8.
    2. Build a 256-bin histogram.
    3. Map every pixel through the normalized CDF (0-255).
    4. Divide by ``n_gray_levels``, truncating.

    Returns a float32 tensor with ``image``'s shape and values in
    0..255 // n_gray_levels. Runs on whatever device ``image`` lives on.

    A constant image (max == min) is treated as all zeros before equalization,
    instead of computing 0/0 = NaN.
    """
    image_0 = image - image.min()                      # min value becomes 0
    peak = image_0.max()
    if peak > 0:
        image_0_255 = image_0 / peak * 255             # max value becomes 255
    else:
        image_0_255 = torch.zeros_like(image_0)        # constant image: avoid 0/0
    image_unit8 = image_0_255.to(torch.uint8)

    histogram = torch.histc(image_unit8.float().view(-1), bins=256, min=0, max=255)
    cdf = histogram.cumsum(0)
    cdf = (cdf / cdf[-1]) * 255

    # perform histogram equalization: look up each pixel's CDF value
    equalized_image = torch.gather(cdf, 0, image_unit8.view(-1).long())
    equalized_image = equalized_image.view(image_unit8.size())

    # reduce to low gray levels: range is now 0..255 // n_gray_levels
    equ_low_gray = torch.div(equalized_image, n_gray_levels, rounding_mode='trunc').to(torch.float32)

    return equ_low_gray


def hist_equ_for_tensor(image_tensor, n_gray_levels):
    """Apply ``hist_equalization`` to channel 0 of every image in a batch.

    Assumes ``image_tensor`` is (N, C, H, W) - TODO confirm with callers. Only
    channel 0 of each image is equalized; other channels of the result stay 0.
    Returns a float32 leaf tensor with ``requires_grad=True`` on CUDA when
    available, else CPU.
    NOTE(review): the result is detached from ``image_tensor``'s graph, so no
    gradient flows back through it.
    """
    dev = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    image_tensor = image_tensor.to(dev)
    equ_reduced_tensor = torch.zeros_like(image_tensor, device=dev)

    for i, image in enumerate(image_tensor):   # already on `dev`
        equ_reduced_tensor[i][0] = hist_equalization(image[0], n_gray_levels)

    # torch.autograd.Variable is deprecated; detach().requires_grad_() yields the
    # same fresh leaf tensor that Variable(..., requires_grad=True) produced.
    return equ_reduced_tensor.to(torch.float32).detach().requires_grad_(True)


# this function is not used
def display_hist_equ_results(image, equ, n_gray_levels):
    """Show ``image`` and its equalized version side by side, next to their histograms.

    Image values are integer-divided by ``n_gray_levels``; histogram values are
    divided by it (true division) and binned into ``n_gray_levels`` bins over
    [0, n_gray_levels]. The images and the histograms are drawn on separate
    axes, so the histograms do not paint over (and re-limit) the image.
    """
    res = np.hstack((image // n_gray_levels, equ // n_gray_levels))  # stacking images side-by-side
    fig, (ax_img, ax_hist) = plt.subplots(1, 2, figsize=(10, 4))
    ax_img.imshow(res)
    ax_img.set_title('before | after')

    ax_hist.hist(np.array(image.flatten()) / n_gray_levels, n_gray_levels, [0, n_gray_levels], color='r')
    ax_hist.hist(np.array(equ.flatten()) / n_gray_levels, n_gray_levels, [0, n_gray_levels], color='b')
    ax_hist.set_xlim([0, n_gray_levels])
    ax_hist.legend(('histogram before', 'histogram after'), loc='upper left')
    plt.show()

**PSNR and SSIM calculations**

In [18]:
def PSNR(image1, image2, m, n, max_i=255):
    """Peak signal-to-noise ratio in dB: ``10 * log10(max_i**2 / MSE)``.

    MSE is the sum of squared differences divided by ``m * n`` (the pixel count
    of one image). Returns ``math.inf`` for identical images. The difference is
    computed in float64 on ``image1``'s device, so integer inputs cannot wrap.
    Intended usage per the original note: m = n = pic_width, max_i = n_gray_levels.
    """
    diff = image1.to(torch.float64) - image2.to(image1.device, torch.float64)
    mse = (torch.sum(diff.pow(2)) / (m * n)).item()
    if mse == 0:
        return math.inf
    return 10 * math.log10(max_i ** 2 / mse)


def SSIM(image1, image2, data_range=None):
    """Mean SSIM between two images, using an 11x11 uniform (box) window.

    Accepts 2-D (H, W) or 3-D (C, H, W) tensors; a leading dim is added
    internally for ``avg_pool2d``. Border windows average only in-image pixels
    (``count_include_pad=False``). Returns a Python float.

    Stabilizing constants: with ``data_range=None`` the legacy values
    ``c1 = 0.01`` and ``c2 = 0.03`` are used, as before. The standard
    definition (Wang et al., 2004) uses ``c1 = (0.01 * L)**2`` and
    ``c2 = (0.03 * L)**2`` for dynamic range ``L``; pass ``data_range=L`` to
    use those.
    """
    if data_range is None:
        c1, c2 = 0.01, 0.03   # legacy constants (not squared)
    else:
        c1, c2 = (0.01 * data_range) ** 2, (0.03 * data_range) ** 2
    window_size = 11      # window size for SSIM calculation
    pad = window_size // 2

    def local_mean(t):
        return F.avg_pool2d(t.unsqueeze(0), window_size, 1, pad, count_include_pad=False).squeeze(0)

    mu1, mu2 = local_mean(image1), local_mean(image2)
    mu1_sq, mu2_sq, mu1_mu2 = mu1.pow(2), mu2.pow(2), mu1 * mu2

    sigma1_sq = local_mean(image1 * image1) - mu1_sq
    sigma2_sq = local_mean(image2 * image2) - mu2_sq
    sigma12 = local_mean(image1 * image2) - mu1_mu2

    numerator = (2 * mu1_mu2 + c1) * (2 * sigma12 + c2)
    denominator = (mu1_sq + mu2_sq + c1) * (sigma1_sq + sigma2_sq + c2)
    return (numerator / denominator).mean().item()

### **Training section**

**Save Outputs**

In [19]:
def save_output_images_vs_original(low_gray_output, output, y_label, pic_width, folder_path, name_sub_folder, n_gray_levels):
    """Upload (low-gray output, network output, original) image triplets as JPEGs; return a wandb preview.

    Each tensor is viewed as a stack of ``pic_width`` x ``pic_width`` images.
    Triplet ``i`` is written to
    ``<folder_path>/<name_sub_folder>/train_image_{i}_{low_gray_out,out,orig}.jpeg``
    for indices 0..30 at most (the loop stops after index 30). An empty
    placeholder blob marks the sub-folder if nothing exists under it yet.

    Returns ``wandb.Image`` of the first low-gray output, or None when there are
    no images. The return used to sit inside the loop, so only the first
    triplet was ever saved.
    """
    # Everything ends up as NumPy, so go to CPU once (no detour via the GPU).
    low_gray_cpu, output_cpu, label_cpu = (t.detach().cpu().view(-1, pic_width, pic_width)
                                           for t in (low_gray_output, output, y_label))

    images_dir = folder_path + '/' + name_sub_folder
    bucket = storage.Client().get_bucket("our_train_test_data")   # set the bucket

    # Trailing '/' so a sibling folder sharing this name as a prefix does not count.
    if not any(bucket.list_blobs(prefix=images_dir + '/', max_results=1)):
        bucket.blob(images_dir + '/').upload_from_string('')      # empty placeholder marks the folder

    heq_out = None
    for i, (low_gray_image, out_image, orig_image) in enumerate(zip(low_gray_cpu, output_cpu, label_cpu)):
        save_img_to_bucket(images_dir + f'/train_image_{i}_low_gray_out.jpeg',
                           low_gray_image.numpy(), n_gray_levels)
        save_img_to_bucket(images_dir + f'/train_image_{i}_out.jpeg',
                           out_image.numpy(), n_gray_levels)
        save_img_to_bucket(images_dir + f'/train_image_{i}_orig.jpeg',
                           orig_image.numpy(), n_gray_levels)

        if i == 0:
            heq_out = wandb.Image(low_gray_image)
        if i > 29:   # cap the upload at indices 0..30
            break

    return heq_out


# TODO: fix this function to save to bucket instead of locally, then uncomment it everywhere
def save_psnr_and_Loss_functions(loss_func, psnr_func, folder_path):
    """Pickle the loss and PSNR histories to ``<folder_path>/Loss_training`` and ``<folder_path>/PSNR_training``.

    Writes to the local filesystem (loss first, then PSNR).
    """
    targets = (('/Loss_training', loss_func), ('/PSNR_training', psnr_func))
    for file_suffix, payload in targets:
        with open(folder_path + file_suffix, "wb") as handle:
            pickle.dump(payload, handle)


def save_outputs(low_gray_output, output, y_label, pic_width, folder_path, name_sub_folder,
                 loss_func, psnr_func, n_gray_levels):
    """Save the epoch's sample images and return the image object logged to W&B as "HEQ-output".

    ``loss_func`` and ``psnr_func`` are currently unused. Pickling them is disabled until
    ``save_psnr_and_Loss_functions`` writes to the bucket instead of the local disk.
    """
    # save_psnr_and_Loss_functions(loss_func, psnr_func, folder_path)
    return save_output_images_vs_original(low_gray_output, output, y_label, pic_width,
                                          folder_path, name_sub_folder, n_gray_levels)

In [20]:
def print_training_messages(epoch, train_loss, avg_psnr, avg_ssim, start, log_path):
    """Print and log the end-of-epoch summary: loss, wall-clock time, mean PSNR and mean SSIM.

    ``epoch`` is 0-based and is reported 1-based. ``start`` is the ``time.time()`` stamp
    taken when the epoch began.
    """
    epoch_no = epoch + 1
    elapsed_sec = round(time.time() - start)
    print_and_log_message(f'Epoch: {epoch_no} \tTraining Loss: {train_loss:.6f}', log_path)
    print_and_log_message(f"Time for epoch {epoch_no} : {elapsed_sec} sec", log_path)
    print_and_log_message(f'Average PSNR for epoch {epoch_no} on training set is {avg_psnr:.6f}', log_path)
    print_and_log_message(f'Average SSIM for epoch {epoch_no} on training set is {avg_ssim:.6f}', log_path)
    

def avg_calc(sum_psnr, sum_ssim, train_loss, num_samples):
    """Turn accumulated per-sample sums into per-sample averages.

    Returns:
        ``(avg_psnr, avg_ssim, avg_loss)``: each sum divided by ``num_samples``.
    """
    totals = (sum_psnr, sum_ssim, train_loss)
    return tuple(total / num_samples for total in totals)

**Train functions**

In [21]:
def train_by_parameters(initial_lr, num_dif_lr, div_factor_lr, train_loader, criterion, batch_size, n_epochs, pic_width,
                        log_path, n_gray_levels, folder_path, model):
    """Train ``model`` over a geometric learning-rate schedule.

    The learning rate starts at ``initial_lr`` and is divided by ``div_factor_lr`` after each of
    the ``num_dif_lr`` stages. Every stage creates a fresh Adam optimizer and runs ``n_epochs``
    epochs through ``train_net`` (sample images go to the ``train_images`` sub-folder).

    Returns:
        The model returned by the last ``train_net`` call (``model`` itself if ``num_dif_lr`` is 0).
    """
    lr_i = initial_lr
    for _ in range(num_dif_lr):
        print_and_log_message(f'learning rate: {lr_i}', log_path)
        # TODO: upload the log file to the bucket here once logs are saved there.
        # A new optimizer per stage, so Adam's state is reset whenever the learning rate changes.
        optimizer = torch.optim.Adam(model.parameters(), lr=lr_i)
        model, optimizer = train_net(model, train_loader, criterion, optimizer, batch_size, n_epochs, pic_width,
                                     log_path, n_gray_levels, lr_i, folder_path, name_sub_folder='train_images')
        lr_i = lr_i / div_factor_lr
    return model


def train_by_vectors(lr_vector, epochs_vector, train_loader, criterion, batch_size, pic_width,
                     log_path, n_gray_levels, folder_path, model):
    """Train ``model`` with an explicit learning-rate schedule.

    Stage ``k`` runs ``epochs_vector[k]`` epochs at learning rate ``lr_vector[k]`` with a fresh
    Adam optimizer, through ``train_net`` (sample images go to the ``train_images`` sub-folder).

    Returns:
        The trained model (``model`` itself if both vectors are empty).

    Raises:
        ValueError: if ``lr_vector`` and ``epochs_vector`` differ in length. ``zip`` would
            otherwise silently drop the unmatched stages.
    """
    lr_vector, epochs_vector = list(lr_vector), list(epochs_vector)
    if len(lr_vector) != len(epochs_vector):
        raise ValueError(f'lr_vector has {len(lr_vector)} entries but epochs_vector has '
                         f'{len(epochs_vector)}; they must pair up one-to-one')

    for lr_i, n_epochs in zip(lr_vector, epochs_vector):
        print_and_log_message(f'learning rate: {lr_i}', log_path)
        # TODO: upload the log file to the bucket here once logs are saved there.
        optimizer = torch.optim.Adam(model.parameters(), lr=lr_i)
        model, optimizer = train_net(model, train_loader, criterion, optimizer, batch_size, n_epochs, pic_width,
                                     log_path, n_gray_levels, lr_i, folder_path, name_sub_folder='train_images')
    return model


def train_net(model, train_loader, criterion, optimizer, batch_size, n_epochs, pic_width, log_path, n_gray_levels,
              curr_lr, folder_path, name_sub_folder, training=1):
    """Train ``model`` for ``n_epochs`` full passes over ``train_loader``.

    After every epoch:
      * the per-sample average loss, PSNR and SSIM are printed and logged;
      * images from the epoch's last batch are saved under ``folder_path/name_sub_folder``;
      * the metrics are logged to W&B.

    Args:
        model: network to train (switched to train mode here).
        train_loader: loader yielding ``(x_data, y_label)`` batches. ``len(train_loader.dataset)``
            is used to average the per-epoch sums.
        criterion: loss function, applied inside ``train_batch``.
        optimizer: optimizer stepped inside ``train_batch``.
        batch_size: unused here; kept for interface compatibility.
        n_epochs: number of times the network sees all the training data.
        pic_width: side length of the square images.
        log_path: log file passed to the message helpers.
        n_gray_levels: forwarded to ``train_batch`` and ``save_outputs``.
        curr_lr: unused here; kept for interface compatibility.
        folder_path, name_sub_folder: destination of the saved sample images.
        training: if truthy, ``train_batch`` back-propagates and updates the parameters.

    Returns:
        ``(model, optimizer)``.

    Raises:
        ValueError: if the loader's dataset is empty, or the loader yields no batches in an epoch.
    """
    num_samples = len(train_loader.dataset)
    if num_samples == 0:
        raise ValueError('train_loader.dataset is empty; nothing to train on')

    model.train()
    wandb.watch(model, criterion, log="all", log_freq=10)

    # Per-batch histories, accumulated across all epochs of this call.
    loss_func, psnr_func = [], []

    for epoch in range(n_epochs):
        start = time.time()   # start measuring time for epoch
        train_loss, sum_psnr, sum_ssim = 0.0, 0.0, 0.0
        last_batch = None
        for x_data, y_label in train_loader:
            low_gray_output, output, optimizer, loss_func, sum_psnr, sum_ssim, psnr_func, train_loss = \
                train_batch(model, x_data, y_label, optimizer, pic_width, loss_func, train_loss,
                            sum_psnr, sum_ssim, psnr_func, n_gray_levels, criterion, training)
            last_batch = (low_gray_output, output, y_label)
        if last_batch is None:
            # Possible e.g. with drop_last=True and fewer samples than one batch.
            raise ValueError('train_loader yielded no batches; nothing to train on')

        avg_psnr, avg_ssim, train_loss = avg_calc(sum_psnr, sum_ssim, train_loss, num_samples)
        print_training_messages(epoch, train_loss, avg_psnr, avg_ssim, start, log_path)
        low_gray_output, output, y_label = last_batch
        heq_out = save_outputs(low_gray_output, output, y_label, pic_width, folder_path, name_sub_folder,
                               loss_func, psnr_func, n_gray_levels)

        # log current values to W&B
        wandb.log({"loss": train_loss, "PSNR": avg_psnr, "SSIM": avg_ssim, "HEQ-output": heq_out})

    return model, optimizer


def train_batch(model, x_data, y_label, optimizer, pic_width, loss_func, train_loss, sum_psnr, sum_ssim, psnr_func,
                n_gray_levels, criterion, training):
    """Run the model on one batch, optionally take an optimizer step, and accumulate statistics.

    The loss is computed between ``hist_equ_for_tensor(output, n_gray_levels)`` and the labels,
    both reshaped to ``(-1, 1, pic_width, pic_width)``.

    Args:
        training: if truthy, back-propagate the loss and call ``optimizer.step()``.
        loss_func, psnr_func: per-batch history lists; one entry is appended to each.
        train_loss, sum_psnr, sum_ssim: running epoch totals (loss is weighted by batch size).

    Returns:
        ``(low_gray_output, output, optimizer, loss_func, sum_psnr, sum_ssim, psnr_func, train_loss)``.
    """
    dev = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    x_data = x_data.to(dev)
    y_label = y_label.to(dev).to(torch.float32)
    n_in_batch = x_data.size(0)

    optimizer.zero_grad()    # clear the gradients of all optimized variables
    output = model(x_data)   # forward pass: compute predictions
    low_gray_output = hist_equ_for_tensor(output, n_gray_levels)

    loss = criterion(low_gray_output.view(-1, 1, pic_width, pic_width), y_label.view(-1, 1, pic_width, pic_width))
    if training:
        loss.backward()      # backward pass: compute gradient of the loss
        optimizer.step()     # parameter update - perform a single optimization step

    weighted_loss = loss.item() * n_in_batch
    train_loss += weighted_loss        # update running training loss
    loss_func.append(weighted_loss)

    # PSNR/SSIM are reporting-only. Computing them on detached tensors without autograd means
    # they build no graph and cannot keep this batch's graph alive through the running sums.
    temp_sum_psnr = 0
    temp_sum_ssim = 0
    with torch.no_grad():
        in_out_images = zip(low_gray_output.detach().view(-1, pic_width, pic_width),
                            y_label.view(-1, pic_width, pic_width))
        for out_image, orig_image in in_out_images:
            temp_sum_psnr += PSNR(out_image, orig_image, pic_width, pic_width, n_gray_levels)
            temp_sum_ssim += SSIM(out_image, orig_image)

    psnr_func.append(temp_sum_psnr / n_in_batch)  # mean PSNR of this batch
    sum_psnr += temp_sum_psnr
    sum_ssim += temp_sum_ssim

    return low_gray_output, output, optimizer, loss_func, sum_psnr, sum_ssim, psnr_func, train_loss

### **Testing section**

In [22]:
def print_testing_messages(test_loss, num_samples, avg_psnr, avg_ssim, log_path):
    """Print and log the test-set summary.

    Args:
        test_loss: loss value to report, formatted with 6 decimals.
        num_samples: number of test images the averages were taken over.
        avg_psnr, avg_ssim: average PSNR / SSIM over the test set.
        log_path: log file passed to ``print_and_log_message``.
    """
    print_and_log_message(f'Average PSNR for {num_samples} images in test set is {avg_psnr}', log_path)
    print_and_log_message(f'Average SSIM for {num_samples} images in test set is {avg_ssim}', log_path)
    print_and_log_message('Test Loss: {:.6f}\n'.format(test_loss), log_path)

    
def save_test_outputs(output, y_label, pic_width, folder_path, name_sub_folder, n_gray_levels, max_images=20):
    """Upload up to ``max_images`` (prediction, ground-truth) image pairs to the ``our_train_test_data`` bucket.

    Pairs are written as ``<folder_path>/<name_sub_folder>/test_image_{i}_out.jpeg`` and
    ``..._orig.jpeg`` via ``save_img_to_bucket``. If nothing exists under that prefix yet, a
    zero-byte ``<prefix>/`` blob is created first as a folder marker.

    Args:
        output, y_label: tensors viewable as ``(-1, pic_width, pic_width)``.
        max_images: maximum number of pairs to upload (default 20, the previous fixed limit).
    """
    # The images are only converted to numpy (via .cpu()) below, so there is no need to copy
    # them onto the training device first.
    out_images = output.view(-1, pic_width, pic_width)
    orig_images = y_label.view(-1, pic_width, pic_width)
    images_dir = folder_path + '/' + name_sub_folder

    storage_client = storage.Client()   # create a client to interact with Google Cloud Storage
    bucket = storage_client.get_bucket("our_train_test_data")   # set the bucket

    if not any(bucket.list_blobs(prefix=images_dir)):
        empty_blob = bucket.blob(images_dir + '/')   # create an empty blob with the folder prefix
        empty_blob.upload_from_string('')            # upload an empty string as the content for the blob

    # range() comes first in the zip so iteration stops after max_images pairs.
    for i, out_image, orig_image in zip(range(max_images), out_images, orig_images):
        save_img_to_bucket(images_dir + f'/test_image_{i}_out.jpeg', out_image.detach().cpu().numpy(), n_gray_levels)
        save_img_to_bucket(images_dir + f'/test_image_{i}_orig.jpeg', orig_image.detach().cpu().numpy(), n_gray_levels)

In [23]:
def test_net(model, test_loader, criterion, batch_size, pic_width, log_path, n_gray_levels, folder_path,
             name_sub_folder='test_images'):
    test_loss, sum_psnr, sum_ssim = 0.0, 0.0, 0.0
    num_samples = len(test_loader.dataset)
    model.eval()
    for x_data, y_label in test_loader:
        output, test_loss, sum_psnr, sum_ssim = test_batch(model, x_data, y_label, pic_width, test_loss, sum_psnr,
                                                           sum_ssim, n_gray_levels, criterion)
    avg_psnr, avg_ssim, train_loss = avg_calc(sum_psnr, sum_ssim, test_loss, num_samples)
    print_testing_messages(test_loss, num_samples, avg_psnr, avg_ssim, log_path)
    save_test_outputs(output, y_label, pic_width, folder_path, name_sub_folder, n_gray_levels)


def test_batch(model, x_data, y_label, pic_width, test_loss, sum_psnr, sum_ssim, n_gray_levels, criterion):
    """Evaluate the model on one batch and add its loss / PSNR / SSIM to the running totals.

    ``test_loss`` accumulates the batch loss weighted by the batch size. ``sum_psnr`` and
    ``sum_ssim`` accumulate one value per image.

    Returns:
        ``(output, test_loss, sum_psnr, sum_ssim)``.
    """
    dev = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    with torch.no_grad():   # evaluation only; avoids building an autograd graph
        x_data, y_label = x_data.to(dev), y_label.to(dev).to(torch.float32)
        output = model(x_data)
        # NOTE(review): training computes its loss on hist_equ_for_tensor(output), whereas the
        # test loss and metrics below use the raw output. Confirm this difference is intended.
        loss = criterion(output.view(-1, 1, pic_width, pic_width),
                         y_label.view(-1, 1, pic_width, pic_width))  # calculate the loss
        test_loss += loss.item() * x_data.size(0)                    # update test loss

        for out_image, orig_image in zip(output.view(-1, pic_width, pic_width),
                                         y_label.view(-1, pic_width, pic_width)):
            sum_psnr += PSNR(out_image, orig_image, pic_width, pic_width, n_gray_levels)
            sum_ssim += SSIM(out_image, orig_image)

    return output, test_loss, sum_psnr, sum_ssim

### **Parameters section**

In [24]:
def get_run_parameters():
    """Return the default hyper-parameters and data settings for a run.

    The result is a 19-tuple in this order:
    batch_size, pic_width, prc_patterns, n_gray_levels, m_patterns, initial_lr, div_factor_lr,
    num_dif_lr, n_epochs, train_samples_coco, test_samples_coco, train_samples_mnist,
    test_samples_mnist, input_shape, lr_vector, epochs_vector, data_sets, generate, TRAIN_BY.
    """
    # Image / measurement geometry
    pic_width = 128
    batch_size = 90
    input_shape = (pic_width, pic_width, batch_size)
    prc_patterns = 60                                       # percent of pic_width ** 2
    m_patterns = (pic_width ** 2) * prc_patterns // 100
    n_gray_levels = 30

    # Schedule for train_by_parameters: lr starts at initial_lr, divided by div_factor_lr per stage
    initial_lr, num_dif_lr, div_factor_lr = 10 ** -3, 3, 100
    n_epochs = 250   # per learning rate

    # Schedule for train_by_vectors: lr_vector[k] is paired with epochs_vector[k]
    lr_vector = [10 ** -3, 10 ** -4, 10 ** -6]
    epochs_vector = [50, 10, 140]

    # Datasets; optional values: 'mnist', 'coco', 'div2k'
    data_sets = ['mnist', 'coco']
    # use 20000-50 coco, 5000-700 mnist for sweep
    # limits: coco maximum 102995 train, 50 test; mnist maximum 60000 train, 10000 test
    train_samples_coco, test_samples_coco = (20000, 10) if 'coco' in data_sets else (0, 0)
    train_samples_mnist, test_samples_mnist = (10000, 10) if 'mnist' in data_sets else (0, 0)

    generate = True        # generate the data even if saved loaders exist
    TRAIN_BY = "Params"    # "Params" -> train_by_parameters, "Vectors" -> train_by_vectors

    return (batch_size, pic_width, prc_patterns, n_gray_levels, m_patterns,
            initial_lr, div_factor_lr, num_dif_lr, n_epochs,
            train_samples_coco, test_samples_coco, train_samples_mnist, test_samples_mnist,
            input_shape, lr_vector, epochs_vector, data_sets, generate, TRAIN_BY)

### **Main**

In [25]:
def _abort_run(title, err, log_path):
    """Log ``err``, raise a W&B error alert titled ``title``, and finish the W&B run as failed."""
    print_and_log_message(str(err), log_path)
    wandb.alert(title=title, text=str(err), level=wandb.AlertLevel.ERROR)
    # TODO: upload the log file to the bucket here once logs are saved there.
    wandb.finish(exit_code=1)


def main():
    """Run one DattNet experiment end to end: W&B setup, data, training and testing.

    Default hyper-parameters come from ``get_run_parameters``. The ones tracked in the W&B
    config are read back from ``wandb.config`` so a sweep agent can override them. A failure
    while loading data, training or testing is logged, reported with ``wandb.alert``, and ends
    the run early with the W&B run marked as failed.
    """
    # Initialize Parameters
    batch_size, pic_width, prc_patterns, n_gray_levels, m_patterns, initial_lr, div_factor_lr, \
        num_dif_lr, n_epochs, train_samples_coco, test_samples_coco, train_samples_mnist, test_samples_mnist, \
        input_shape, lr_vector, epochs_vector, data_sets, generate, TRAIN_BY = get_run_parameters()

    run = wandb.init(
        project="DattNet",
        tags=[f"by{TRAIN_BY}"],

        # track hyperparameters and run metadata
        config={
            # Variables for BY_VECTORS run
            # "lr_vector": lr_vector,
            # "epochs_vec": epochs_vector,

            # Variables for BY_PARAMS run
            "init_lr": initial_lr,
            "num_of_lr": num_dif_lr,
            "div_f_lr": div_factor_lr,
            "epochs_for_lr": n_epochs,

            "batch_size": batch_size,
            "prc_patterns": prc_patterns,
            "train_num_coco": train_samples_coco,
            "train_num_mnist": train_samples_mnist
        })

    # Read back the (possibly sweep-overridden) hyper-parameters
    # lr_vector = wandb.config.lr_vector
    # epochs_vector = wandb.config.epochs_vec
    initial_lr = wandb.config.init_lr
    num_dif_lr = wandb.config.num_of_lr
    div_factor_lr = wandb.config.div_f_lr
    n_epochs = wandb.config.epochs_for_lr
    batch_size = wandb.config.batch_size
    prc_patterns = wandb.config.prc_patterns
    train_samples_coco = wandb.config.train_num_coco
    train_samples_mnist = wandb.config.train_num_mnist

    # Values derived from swept parameters must be recomputed. Otherwise a sweep that changes
    # prc_patterns or batch_size builds the data and the model with the default sizes.
    m_patterns = (pic_width ** 2) * prc_patterns // 100   # same formula as get_run_parameters
    input_shape = (pic_width, pic_width, batch_size)

    dev = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    if dev.type == 'cuda':
        torch.cuda.set_device(dev)   # set_device rejects a CPU device, so only call it with CUDA

    folder_path = make_folder('DattNet', train_samples_coco, train_samples_mnist, batch_size, n_gray_levels, prc_patterns)
    log_path = print_run_info_to_log(batch_size, pic_width, prc_patterns, n_gray_levels, m_patterns, initial_lr,
                                     div_factor_lr, num_dif_lr, n_epochs, train_samples_coco, train_samples_mnist,
                                     test_samples_coco, test_samples_mnist, lr_vector, epochs_vector, run.name, TRAIN_BY)

    try:   # get data
        train_loader, test_loader, data_time_message, _patterns = get_data(generate, log_path, batch_size, pic_width,
                                                                           prc_patterns, m_patterns, train_samples_coco,
                                                                           train_samples_mnist, test_samples_coco,
                                                                           test_samples_mnist, n_gray_levels, data_sets)
        print_and_log_message(data_time_message, log_path)
        wandb.alert(title="Data Generated", text="Finished loading/generating the data successfully", level=wandb.AlertLevel.INFO)
    except Exception as err:
        _abort_run("Failed to get the Data!", err, log_path)
        return      # stop the run after the error

    # define Model and Loss
    model = DattNet(pic_width, m_patterns, input_shape)
    model.to(dev)
    criterion = nn.MSELoss()

    try:
        # Load past model if exists
        # model_path = folder_path + '/model.pth'
        # if exists(model_path) and not generate:
        #     model.load_state_dict(torch.load(model_path))     # load saved model
        # else:

        if TRAIN_BY == "Vectors":
            model = train_by_vectors(lr_vector, epochs_vector, train_loader, criterion, batch_size,
                                     pic_width, log_path, n_gray_levels, folder_path, model)
        else:
            model = train_by_parameters(initial_lr, num_dif_lr, div_factor_lr, train_loader, criterion, batch_size,
                                        n_epochs, pic_width, log_path, n_gray_levels, folder_path, model)

        # torch.save(model.state_dict(), model_path)    # TODO: fix saving Model to bucket instead of locally
    except Exception as err:
        _abort_run("Run Crashed during Training!", err, log_path)
        return     # stop the run after the error

    try:
        test_net(model, test_loader, criterion, batch_size, pic_width, log_path, n_gray_levels, folder_path)
    except Exception as err:
        _abort_run("Run Crashed during Testing!", err, log_path)
        return     # stop the run after the error

    wandb.alert(title="Run Finished!", text="Miracles Happen :)", level=wandb.AlertLevel.INFO)
    wandb.finish()
    # TODO: upload the log file to the bucket here once logs are saved there.

**Run `main()` directly (uncomment the `if __name__ == "__main__":` guard), or through a W&B agent for a sweep:**

In [None]:
%%capture output

# Option 1 - run a single experiment directly:
# if __name__ == "__main__":
#     main()

# Option 2 - let a W&B sweep agent run main() (count=2 runs) for the sweep below.
# main() reads the swept hyper-parameters from wandb.config.
# `%%capture output` stores everything this cell prints in `output`; the next cell displays it.
wandb.agent(sweep_id="noaavra/DattNet/x2w9ltj4", function=main, count=2)

In [None]:
output.show()