# Loading Data 

In [1]:
import os
import random
import threading
from concurrent.futures import ThreadPoolExecutor
from io import BytesIO
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import PIL
import requests
import torch
from PIL import Image
from torchvision import transforms

In [4]:
def load_local_sample(location):
    """Load an array saved with NumPy, rescale it by 0.55 and return a tensor.

    Args:
        location: File path (or file object) handed to ``np.load``.

    Returns:
        ``torch.Tensor`` built from the rescaled array.
    """
    # NOTE(review): ``rescale`` is not imported anywhere in the visible part of
    # this file; the keyword arguments look like ``skimage.transform.rescale``
    # - confirm and add the import.
    shrunk = rescale(
        np.load(location),
        0.55,
        multichannel=False,
        anti_aliasing=True,
        mode='reflect',
    )
    return torch.from_numpy(shrunk)

def load_html_sample(url, timeout=10):
    """Download an image over HTTP and return its pixels as a NumPy array.

    Args:
        url: Address of the image file.
        timeout: Seconds to wait for the server before giving up, so a stalled
            connection cannot block the caller forever.

    Returns:
        ``np.ndarray`` holding the decoded image.

    Raises:
        requests.RequestException: On connection problems, timeouts or a
            non-success HTTP status (instead of trying to decode an error page
            as an image).
    """
    response = requests.get(url, timeout=timeout)
    response.raise_for_status()
    with Image.open(BytesIO(response.content)) as img:
        return np.array(img)

def img_cut(data, top_left_vertical=60, top_left_horizontal=100, height=480, width=575):
    """Crop a fixed rectangular window out of an image.

    Args:
        data: Image as a ``np.ndarray`` (rows first) or a ``PIL.Image.Image``.
        top_left_vertical: Row index of the window's top edge.
        top_left_horizontal: Column index of the window's left edge.
        height: Number of rows in the window.
        width: Number of columns in the window.

    Returns:
        The cropped image, in the same container type as ``data``.

    Raises:
        TypeError: If ``data`` is neither an ndarray nor a PIL image.
    """
    rows = slice(top_left_vertical, top_left_vertical + height)
    cols = slice(top_left_horizontal, top_left_horizontal + width)

    if isinstance(data, np.ndarray):
        return data[rows, cols]
    if isinstance(data, PIL.Image.Image):
        return Image.fromarray(np.array(data)[rows, cols])
    raise TypeError(
        f"img_cut expects a numpy array or a PIL image, got {type(data).__name__}"
    )

def gray_url(url):
    """Download an image, crop it with ``img_cut`` and convert it to grayscale.

    Args:
        url: Address of the image file.

    Returns:
        2-D array of luminance values as produced by ``gray_scale``.
    """
    # The original called ``load_html``, which does not exist in this file;
    # the download helper is ``load_html_sample``.
    return gray_scale(img_cut(load_html_sample(url)))

def gray_scale(img):
    """Collapse the first three channels of an image into one luminance plane.

    Args:
        img: Array indexed as ``[row, column, channel]`` with at least three
            channels.

    Returns:
        Weighted sum ``0.2989 * c0 + 0.5870 * c1 + 0.1140 * c2``.
    """
    weights = (0.2989, 0.5870, 0.1140)
    planes = [weight * img[:, :, k] for k, weight in enumerate(weights)]
    return planes[0] + planes[1] + planes[2]

def normalize_array(img_array):
    """Scale a 3-channel image by 1/255 and normalise each channel.

    Args:
        img_array: NumPy array whose last axis holds three channels (it is
            permuted to channel-first for the normalisation step).

    Returns:
        NumPy array with the same axis order as the input, after dividing by
        255 and applying ``Normalize`` with mean 0.5 and std 0.5 per channel.
    """
    channel_first = torch.from_numpy(img_array).permute(2, 0, 1)
    scaled = channel_first.float() / 255.0
    centre_and_scale = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    return centre_and_scale(scaled).permute(1, 2, 0).numpy()

def normalize_2d_array(arr):
    """Min-max scale an array into the range [0, 1].

    Args:
        arr: Array-like of numbers.

    Returns:
        Floating-point array of the same shape where the minimum maps to 0 and
        the maximum to 1.  A constant input (max == min) maps to all zeros
        instead of dividing by zero.
    """
    arr = np.asarray(arr)
    # Work in floating point so that ``max - min`` cannot overflow narrow
    # integer dtypes (e.g. int8 spanning -128..127).
    if not np.issubdtype(arr.dtype, np.floating):
        arr = arr.astype(np.float64)

    min_val = np.min(arr)
    span = np.max(arr) - min_val
    if span == 0:
        return np.zeros_like(arr)

    return (arr - min_val) / span

# Augmentation Methods

List of Over Sampling Methods:
I. Geometric Augmentation
1. Random Crop
2. Artificial Noising
3. Lateral Inversion
4. Horizontal Inversion

II. Photometric Augmentation
1. Image Blurring

III. Kernel

IV. Mixing

V. Random Erasing

VI. SMOTE

Down Sampling Methods:
I. Random Down Sampling


In [3]:
import tqdm

### TODO: revisit once the dataset's final data type is settled and adapt this code.
def batch_augmentation(dataset: list, method):
    """Apply one augmentation function to the image of every sample.

    Args:
        dataset: Sequence of samples; element ``0`` of each sample is the image.
        method: Callable taking an image and returning the augmented image.

    Returns:
        List with the augmented images, in the order of ``dataset``.
    """
    # The file does ``import tqdm``, which binds the *module*; calling it raises
    # TypeError.  Import the progress-bar class itself.
    from tqdm import tqdm as progress_bar

    images = [sample[0] for sample in dataset]
    return [method(image) for image in progress_bar(images)]
    

In [4]:
import cv2

def random_crop(image, crop_size=(360,460)):
    """Cut a random window out of an image and resize it back to the input size.

    Args:
        image: Array indexed as ``[row, column, channel]`` with at least three
            channels; only the first three are used.
        crop_size: ``(rows, columns)`` of the window to cut out.

    Returns:
        3-channel array with the same number of rows and columns as ``image``.

    Raises:
        ValueError: If the crop window is larger than the image.
    """
    height, width = image.shape[:2]
    crop_rows, crop_cols = crop_size
    if crop_rows > height or crop_cols > width:
        raise ValueError(
            f"crop_size {crop_size} does not fit into an image of size {(height, width)}"
        )

    top = random.randint(0, height - crop_rows)
    # The column offset must be drawn from the column extents; the original
    # reused the row extents here.
    left = random.randint(0, width - crop_cols)

    patch = np.dstack([
        image[top:top + crop_rows, left:left + crop_cols, channel]
        for channel in range(3)
    ])
    # cv2.resize takes the target size as (width, height), not the array shape.
    return cv2.resize(patch, (width, height))

def img_noising(image, noise_intensity=0.2):
    """Overlay salt-and-pepper noise on the first three channels of an image.

    For every pixel a uniform random number is drawn and rounded to one
    decimal.  Values above ``1 - noise_intensity`` turn the pixel white (255),
    values equal to it turn the pixel black (0), all other pixels are kept.
    The same decision is applied to all three channels of a pixel.

    Args:
        image: Array indexed as ``[row, column, channel]`` with at least three
            channels, of any height and width.
        noise_intensity: Larger values corrupt more pixels.

    Returns:
        New ``uint8`` array of shape ``(rows, columns, 3)``; the input is not
        modified.

    Raises:
        ValueError: If ``image`` is not 3-D with at least three channels.
    """
    image = np.asarray(image)
    if image.ndim != 3 or image.shape[2] < 3:
        raise ValueError("img_noising expects an array of shape (rows, columns, >=3)")

    noise_threshold = 1 - noise_intensity
    noisy = image[:, :, :3].astype('uint8')  # astype copies, input stays intact

    # One draw per pixel, shared by the three channels.  Vectorised so the
    # function no longer depends on a hard-coded 32x32 (=1024 pixel) input.
    regulator = np.round(np.random.random(noisy.shape[:2]), 1)
    # isclose instead of == because the threshold is the result of a float
    # subtraction.
    pepper = np.isclose(regulator, noise_threshold)
    salt = (regulator > noise_threshold) & ~pepper

    noisy[salt] = 255
    noisy[pepper] = 0
    return noisy

def horizontal_flip(image):
    """Mirror the first three channels of an image left-to-right.

    Args:
        image: Array indexed as ``[row, column, channel]`` with at least three
            channels.

    Returns:
        New 3-channel array with the column order reversed.
    """
    mirrored = [image[:, :, channel][:, ::-1] for channel in range(3)]
    return np.dstack(mirrored)

def vertical_flip(image):
    """Mirror the first three channels of an image top-to-bottom.

    Args:
        image: Array indexed as ``[row, column, channel]`` with at least three
            channels.

    Returns:
        New 3-channel array with the row order reversed.
    """
    mirrored = [image[:, :, channel][::-1, :] for channel in range(3)]
    return np.dstack(mirrored)

def rotation(image,angle=90):
    """Rotate the first three channels of an image by 90, 180 or 270 degrees.

    Args:
        image: Array indexed as ``[row, column, channel]`` with at least three
            channels.
        angle: 90, 180 or 270; any other value leaves the orientation
            unchanged.

    Returns:
        New 3-channel array, rotated with ``np.rot90`` once per 90 degrees.
    """
    channels = [image[:, :, index] for index in range(3)]

    quarter_turns = 0
    for turns, supported_angle in enumerate((90, 180, 270), start=1):
        if angle == supported_angle:
            quarter_turns = turns
            break

    return np.dstack([np.rot90(channel, quarter_turns) for channel in channels])
    

In [5]:
def pad_image(image, padding=2):
    """Surround the first three channels of an image with a zero border.

    Args:
        image: Array indexed as ``[row, column, channel]`` with at least three
            channels.
        padding: Border width in pixels added on every side; ``0`` is allowed
            and returns an unpadded copy.

    Returns:
        ``uint8`` array of shape ``(rows + 2*padding, columns + 2*padding, 3)``.

    Raises:
        ValueError: If ``padding`` is negative.
    """
    padding = int(padding)
    if padding < 0:
        raise ValueError("padding must be non-negative")

    rgb = np.dstack([image[:, :, channel] for channel in range(3)]).astype('uint8')
    # np.pad copes with padding == 0, where the original ``[p:-p]`` slice was
    # empty and the assignment failed.
    border = (padding, padding)
    return np.pad(rgb, (border, border, (0, 0)), mode='constant', constant_values=0)

def blur_image(image, kernel_size=5, padding=2):
    """Blur the first three channels of an image with a fixed 5x5 Gaussian kernel.

    Args:
        image: Array indexed as ``[row, column, channel]`` with at least three
            channels.
        kernel_size: Side length of the kernel.  Only 5 is supported because
            the kernel weights are fixed.
        padding: Zero border added via ``pad_image`` before filtering; with the
            default of 2 the output has the same height and width as the input.

    Returns:
        ``uint8`` array of shape
        ``(rows + 2*padding - kernel_size + 1, columns + 2*padding - kernel_size + 1, 3)``.

    Raises:
        ValueError: If ``kernel_size`` is not 5 or the (padded) image is
            smaller than the kernel.
    """
    gauss_5 = np.array([[1, 4, 7, 4, 1],
                        [4, 16, 26, 16, 4],
                        [7, 26, 41, 26, 7],
                        [4, 16, 26, 16, 4],
                        [1, 4, 7, 4, 1]])

    # Previously a different kernel_size made every window product fail, the
    # exception was swallowed and an all-zero image came back.
    if kernel_size != gauss_5.shape[0]:
        raise ValueError("blur_image only supports kernel_size=5")

    if padding > 0:
        image = pad_image(image, padding=padding)

    out_rows = image.shape[0] - kernel_size + 1
    out_cols = image.shape[1] - kernel_size + 1
    if out_rows <= 0 or out_cols <= 0:
        raise ValueError("image is smaller than the blur kernel")

    pixels = np.dstack([image[:, :, channel] for channel in range(3)]).astype(np.float64)

    # Accumulate the integer-weighted, shifted copies of the image and divide
    # once at the end: for integer pixel values the sum is exact, so the final
    # truncation to uint8 is not affected by floating-point rounding noise.
    weighted = np.zeros((out_rows, out_cols, 3), dtype=np.float64)
    for d_row in range(kernel_size):
        for d_col in range(kernel_size):
            weighted += gauss_5[d_row, d_col] * pixels[d_row:d_row + out_rows,
                                                       d_col:d_col + out_cols]

    return (weighted / gauss_5.sum()).astype('uint8')

# Helper Functions for Training

In [6]:
class Accumulator:
    """Keep running sums of ``n`` quantities side by side."""

    def __init__(self, n):
        """Start with ``n`` sums, all at 0.0."""
        self.data = [0.0 for _ in range(n)]

    def add(self, *args):
        """Add the i-th argument (converted to float) to the i-th sum."""
        updated = []
        for current, increment in zip(self.data, args):
            updated.append(current + float(increment))
        self.data = updated

    def reset(self):
        """Set every sum back to 0.0, keeping the number of sums."""
        self.data = [0.0 for _ in self.data]

    def __getitem__(self, idx):
        """Return the sum stored at position ``idx``."""
        return self.data[idx]


def evaluate_accuracy_gpu(net, data_iter, loss_func, device=None):
    """Evaluate a network on a data iterator without tracking gradients.

    Args:
        net: Model; it is switched to evaluation mode.
        data_iter: Iterable yielding ``(X, y)`` batches.
        loss_func: Callable ``loss_func(y_hat, y)`` returning a tensor.
        device: Device the batches are moved to before the forward pass.

    Returns:
        Tuple ``(accuracy, loss)``: the fraction of correctly classified
        samples over all batches, and the summed loss divided by the total
        number of label elements.
    """
    net.eval()
    # correct predictions, samples seen, summed loss, label elements seen
    metric = Accumulator(4)
    with torch.no_grad():
        for X, y in data_iter:
            X = X.to(device).to(torch.float)
            y = y.to(device).to(torch.long)
            y_hat = net(X)
            loss = loss_func(y_hat, y)
            batch_size = y.shape[0]
            # accuracy() returns the *fraction* correct within the batch, so it
            # has to be weighted by the batch size before being summed; adding
            # raw fractions and dividing by y.numel() gave a meaningless value.
            metric.add(accuracy(y_hat, y) * batch_size, batch_size,
                       loss.sum(), y.numel())
    return metric[0] / metric[1], metric[2] / metric[3]

def accuracy(outputs, labels):
    """Return the fraction of rows whose arg-max matches between two tensors.

    Args:
        outputs: 2-D tensor of scores, one row per sample.
        labels: 2-D tensor with the same number of rows; the position of each
            row's maximum is taken as the target class.

    Returns:
        Python float: matching rows divided by the number of rows.
    """
    predicted = outputs.max(dim=1).indices
    expected = labels.max(dim=1).indices
    n_correct = (predicted == expected).sum().item()
    return n_correct / len(predicted)


def save_model(epoch, model, optimizer, scheduler, checkpoint_dir, train_loss_history, filename):
    """Write a training checkpoint and remove older checkpoints in the directory.

    Args:
        epoch: Epoch number stored in the checkpoint.
        model: Object providing ``state_dict()``.
        optimizer: Optimizer providing ``state_dict()``; its class name is
            stored as well.
        scheduler: Optional scheduler; stored the same way when not ``None``.
        checkpoint_dir: Directory for the checkpoint (created if missing).
        train_loss_history: Sequence of loss values; saved both inside the
            checkpoint and as ``train_loss_history_cnn.npy``.
        filename: Checkpoint file name, must end in ``.pt`` or ``.pth``.

    Raises:
        ValueError: If ``filename`` does not have a checkpoint extension.
    """
    checkpoint_suffixes = ('.pt', '.pth')
    # Explicit exception instead of ``assert``, which is stripped under -O.
    if Path(filename).suffix not in checkpoint_suffixes:
        raise ValueError(f"filename must end in .pt or .pth, got {filename!r}")

    p = Path(checkpoint_dir)
    p.mkdir(parents=True, exist_ok=True)

    np.save(p / 'train_loss_history_cnn', train_loss_history)

    output = {
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'optimizer_type': type(optimizer).__name__,
        'epoch': epoch,
        # Plain Python numbers so the checkpoint is self-contained.
        'train_loss_history': np.asarray(train_loss_history).tolist(),
    }

    if scheduler is not None:
        output['scheduler_state_dict'] = scheduler.state_dict()
        output['scheduler_type'] = type(scheduler).__name__

    target = p / filename
    torch.save(output, target)

    # Remove older checkpoints only after the new one is safely on disk, and
    # match on the file extension so unrelated files whose names merely
    # contain ".pt" are left alone.
    for stale in p.iterdir():
        if stale.is_file() and stale.suffix in checkpoint_suffixes and stale != target:
            stale.unlink()


def load_net(checkpoint_dir=None):
    """Create a ``MyNet`` and, if a checkpoint is available, restore its weights.

    Args:
        checkpoint_dir: Directory holding a single checkpoint file and
            optionally ``train_loss_history_cnn.npy``; ``None`` or a missing
            directory means "start fresh".

    Returns:
        Tuple ``(net, epoch, train_loss_history)``.  Without a usable
        checkpoint this is ``(net, 0, [])``.

    Raises:
        RuntimeError: If the directory contains more than one checkpoint file,
            because it is then unclear which one to restore.
    """
    net = MyNet()

    if checkpoint_dir is None:
        return net, 0, []

    p = Path(checkpoint_dir)
    if not p.is_dir():
        return net, 0, []

    files = [f for f in os.listdir(p) if Path(f).suffix in ('.pt', '.pth')]
    if not files:
        # Previously this fell through to ``checkpoint['epoch']`` with
        # ``checkpoint`` unbound and crashed with a NameError.
        return net, 0, []
    if len(files) > 1:
        raise RuntimeError(
            f"Expected exactly one checkpoint in {p}, found {len(files)}: {sorted(files)}"
        )

    checkpoint = torch.load(p / files[0])
    net.load_state_dict(checkpoint['model_state_dict'])
    epoch = checkpoint['epoch']

    history_file = p / 'train_loss_history_cnn.npy'
    train_loss_history = np.load(history_file).tolist() if history_file.is_file() else []

    return net, epoch, train_loss_history

import torch.optim as optim


def load_model(checkpointdir):
    """Restore network, optimizer and scheduler from a checkpoint directory.

    Args:
        checkpointdir: Directory containing a ``.pt`` checkpoint, or ``None``
            to get a freshly constructed network.

    Returns:
        ``(net, optimizer, scheduler, epoch, train_loss_history)`` after a
        successful load (``optimizer``/``scheduler`` are ``None`` when they
        could not be restored); ``None`` if the directory is missing or holds
        no ``.pt`` file; ``(net, 0, [])`` when ``checkpointdir`` is ``None``.
    """
    net = MyNet()

    if checkpointdir is None:
        # NOTE(review): this branch returns a 3-tuple while a successful load
        # returns a 5-tuple; kept as-is so existing callers do not break.
        return net, 0, []

    p = Path(checkpointdir)
    if not p.is_dir():
        print('Checkpoint Error')
        return None

    # Sorted so the choice is deterministic if several files are present.
    files = sorted(f for f in os.listdir(p) if f.endswith('.pt'))
    if not files:
        print('No model file found')
        return None

    checkpoint = torch.load(p / files[0])
    net.load_state_dict(checkpoint['model_state_dict'])

    optimizer = None
    optimizer_type = checkpoint.get('optimizer_type')
    if optimizer_type:
        try:
            optimizer = getattr(optim, optimizer_type)(net.parameters())
        except TypeError as err:
            # Some optimizers cannot be built without extra arguments.
            print(f'Optimizer {optimizer_type} could not be rebuilt: {err}')
        else:
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

    scheduler = None
    scheduler_type = checkpoint.get('scheduler_type')
    # A scheduler can only wrap an optimizer that was actually restored.
    if scheduler_type and optimizer is not None:
        try:
            scheduler = getattr(optim.lr_scheduler, scheduler_type)(optimizer)
        except TypeError as err:
            # Many schedulers need constructor arguments that are not stored
            # in the checkpoint.
            print(f'Scheduler {scheduler_type} could not be rebuilt: {err}')
        else:
            scheduler.load_state_dict(checkpoint['scheduler_state_dict'])

    epoch = checkpoint['epoch']

    # Checkpoints may not contain the loss history; fall back to the separate
    # .npy file if it exists, otherwise to an empty history.
    train_loss_history = checkpoint.get('train_loss_history')
    if train_loss_history is None:
        history_file = p / 'train_loss_history_cnn.npy'
        train_loss_history = np.load(history_file).tolist() if history_file.is_file() else []

    print('Load Successful')

    return net, optimizer, scheduler, epoch, train_loss_history

# Data Loader Customized for Gravity Spy and Data Augmentation

In [None]:
'''
PyTorch defines the class Dataset and DataLoader in a way that loading data can be easily customized. We make use of this advantage.

The Gravity Spy dataset is recorded in a number of CSV files, which list the URLs of each time-frequency spectrogram. The URLs 
alone take up 1.1 GB of memory. The images, original and later augmented, will take up more. 

We customize our Dataset and DataLoader classes so that we can perform augmentations easily while saving memory. 
'''

'''
Keep the data as a DataFrame and the images as NumPy arrays before feeding them to the model
'''

from torch.utils.data import DataLoader
from torch.utils.data import Dataset

class GS_Aug_Dataset(Dataset):
    """Dataset whose augmented samples are stored as references, not images.

    The backing DataFrame is addressed by position: column 0 is the label and
    column 1 is either the first image URL (original row) or a list
    ``[index_of_parent_row, one_hot]`` (augmented row), where ``one_hot`` marks
    one entry of ``list_aug``.  Columns 2-4 of original rows hold the remaining
    three URLs.  Images are downloaded and augmented only when a sample is
    requested.
    """

    def __init__(self, x, num_aug, list_aug):
        """Store the data and the available augmentation functions.

        Args:
            x: DataFrame with columns (label, url1, url2, url3, url4).
            num_aug: Number of augmentation methods (stored as given).
            list_aug: List of augmentation callables, each taking and
                returning an image.
        """
        # Parent references are row positions, so the index must be 0..N-1.
        self.augmented_x = x.reset_index(drop=True)
        self.original_length = x.shape[0]
        self.num_aug = num_aug
        self.list_aug = list_aug

    def _resolve(self, idx):
        """Follow parent references from row ``idx`` back to an original row.

        Returns:
            ``(original_index, flags)`` where ``flags[i]`` counts how often
            ``list_aug[i]`` was requested along the chain.
        """
        flags = [0] * len(self.list_aug)
        row = idx
        content = self.augmented_x.iloc[row, 1]
        while not isinstance(content, str):
            if not isinstance(content, (list, tuple)):
                raise TypeError(
                    f"row {row}: expected a URL or [parent, one_hot], got {content!r}"
                )
            parent, one_hot = content
            flags = [a + b for a, b in zip(flags, one_hot)]
            row = int(parent)
            content = self.augmented_x.iloc[row, 1]
        return row, flags

    # The Dataset/DataLoader protocol expects __getitem__ to return (x, y).
    def __getitem__(self, idx):
        """Return ``([img1, img2, img3, img4], label)`` for sample ``idx``.

        Each image is downloaded, cropped and converted to grayscale, then
        every flagged augmentation is applied once, in ``list_aug`` order.
        """
        original_index, flags = self._resolve(idx)

        images = []
        for url in self.augmented_x.iloc[original_index, 1:5]:
            image = gray_scale(img_cut(load_html_sample(url)))
            # NOTE(review): augmentations receive the 2-D grayscale image (the
            # order used by the original code) - confirm the functions in
            # list_aug accept that.
            for requested, augment in zip(flags, self.list_aug):
                if requested:
                    image = augment(image)
            images.append(image)

        return images, self.augmented_x.iloc[original_index, 0]

    def __len__(self):
        """Number of original plus augmented rows."""
        return len(self.augmented_x)

    def _augment_(self, method: list, aug_on_aug=False, labels: list = None):
        """Append augmented rows that reference existing rows.

        Args:
            method: Augmentation callables to apply; each must be an element
                of ``list_aug``.  One new row per selected row is appended for
                every callable.
            aug_on_aug: If true, already augmented rows are augmented as well;
                otherwise only the original rows are used.
            labels: Labels to augment.  ``None`` or a list starting with
                ``'All'`` selects every label.
        """
        labels = ['All'] if labels is None else labels

        if aug_on_aug:
            pool = self.augmented_x
        else:
            pool = self.augmented_x.iloc[:self.original_length]

        if not (labels and labels[0] == 'All'):
            pool = pool[pool.iloc[:, 0].isin(labels)]
        if pool.empty:
            return

        label_col, content_col = self.augmented_x.columns[:2]
        for chosen in method:
            one_hot = [0] * len(self.list_aug)
            one_hot[self.list_aug.index(chosen)] = 1

            new_rows = pd.DataFrame({
                label_col: pool.iloc[:, 0].to_numpy(),
                content_col: pd.Series(
                    [[int(position), list(one_hot)] for position in pool.index],
                    dtype=object,
                ),
            })
            self.augmented_x = pd.concat([self.augmented_x, new_rows],
                                         ignore_index=True)

    # Aliases for the original single-underscore method names.
    _init_ = __init__
    _getitem_ = __getitem__
    _len_ = __len__

class GS_Simple_Dataset1(Dataset):
    """Dataset that downloads four images per row and tiles them 2x2."""

    def __init__(self, x):
        """Store the data.

        Args:
            x: DataFrame addressed by position: column 0 is the label,
               columns 1-4 are the four image URLs.
        """
        self.x = x

    def __len__(self):
        """Number of rows in the DataFrame."""
        return len(self.x)

    def __getitem__(self, idx):
        """Return ``(tiled_image, label)`` for row ``idx``.

        The four images are downloaded, cropped and converted to grayscale,
        then arranged as [[url1, url2], [url3, url4]].
        """
        # Keep the processed images; the original loop rebound its loop
        # variable and then concatenated the raw URL strings.
        tiles = [
            gray_scale(img_cut(load_html_sample(self.x.iloc[idx, column])))
            for column in range(1, 5)
        ]

        top_row = np.concatenate((tiles[0], tiles[1]), axis=1)
        bottom_row = np.concatenate((tiles[2], tiles[3]), axis=1)
        final_image = np.concatenate((top_row, bottom_row), axis=0)

        return final_image, self.x.iloc[idx, 0]

    # Aliases for the original (misspelled) method names.
    _init_ = __init__
    _len_ = __len__
    _gititem_ = __getitem__


class GS_Smple_Dataset2(torch.utils.data.Dataset):
    """Dataset that downloads four images per row, caches them and tiles them 2x2."""

    def __init__(self, data, label_col, url_cols, transform=None, max_workers=4):
        """Store the data and set up the download helpers.

        Args:
            data: DataFrame with one label column and four URL columns.
            label_col: Name of the label column.
            url_cols: Names of the four URL columns, in tiling order
                (top-left, top-right, bottom-left, bottom-right).
            transform: Optional callable applied to every downloaded image and
                again to the tiled result.
            max_workers: Number of threads used to download one sample's images.
        """
        self.data = data.reset_index(drop=True)  # rows are addressed 0..N-1
        self.label_col = label_col
        self.url_cols = url_cols
        self.transform = transform
        self.max_workers = max_workers
        self.executor = ThreadPoolExecutor(max_workers=max_workers)
        self.lock = threading.Lock()
        self.image_cache = {}  # url -> loaded (and transformed) image

    def __len__(self):
        """Number of rows in the DataFrame."""
        return len(self.data)

    @staticmethod
    def _as_tensor(img):
        """Convert an image (PIL image, array or tensor) to a tensor via NumPy."""
        return torch.from_numpy(np.array(img))

    def _load_image(self, url):
        """Return the image for ``url`` from the cache or the network.

        Returns ``None`` (after printing the error) if the image could not be
        downloaded, decoded or transformed.
        """
        with self.lock:
            if url in self.image_cache:
                return self.image_cache[url]

        # The download runs outside the lock so that several images can be
        # fetched at the same time.
        try:
            response = requests.get(url, timeout=10)
            response.raise_for_status()
            img = Image.open(BytesIO(response.content)).convert('RGB')
            if self.transform:
                img = self.transform(img)
        except Exception as e:
            print(f"Error loading {url}: {str(e)}")
            return None

        with self.lock:
            self.image_cache[url] = img
        return img

    def _placeholder(self, loaded):
        """Build an all-zero image shaped like a successfully loaded one."""
        reference = next((img for img in loaded if img is not None), None)
        if reference is None:
            with self.lock:
                reference = next(iter(self.image_cache.values()), None)
        if reference is None:
            # Previously this case surfaced as a bare StopIteration.
            raise RuntimeError(
                "No image could be loaded and the cache is empty; "
                "cannot determine the placeholder shape"
            )
        return torch.zeros_like(self._as_tensor(reference))

    def __getitem__(self, idx):
        """Return ``(tiled_image, label)`` for row ``idx``."""
        # 1. Fetch the row's label and URLs.
        with self.lock:
            label = self.data.loc[idx, self.label_col]
            urls = [self.data.loc[idx, col] for col in self.url_cols]

        # 2. Load the images, using the thread pool for the downloads.
        loaded = list(self.executor.map(self._load_image, urls))
        images = [img if img is not None else self._placeholder(loaded)
                  for img in loaded]

        # 3. Tile the four images.
        img1, img2, img3, img4 = (self._as_tensor(img) for img in images)

        # NOTE(review): concatenating on dims 2 and 1 assumes channel-first
        # (C, H, W) tensors, i.e. a transform that produces them - confirm.
        top_row = torch.cat((img1, img2), dim=2)  # side by side
        bottom_row = torch.cat((img3, img4), dim=2)
        final_image = torch.cat((top_row, bottom_row), dim=1)  # stacked vertically

        # 4. Augmentation.
        # NOTE(review): the same transform already ran on every single image
        # above; confirm that applying it a second time is intended.
        if self.transform:
            final_image = self.transform(final_image)

        return final_image, label

class GS_Simple_Dataset3(Dataset):
    """Dataset that tiles four grayscale images per row into one tensor.

    Expects the columns ``url1`` .. ``url4`` and ``ml_label``.  URLs that do
    not answer a HEAD request with status 200 are replaced by repeating the
    reachable ones.
    """

    def __init__(self, x: pd.DataFrame):
        # Force a 0..N-1 row index so that .iloc positions match the index.
        self.x = x.reset_index(drop=True)

    def __len__(self):
        """Number of rows in the DataFrame."""
        return len(self.x)

    @staticmethod
    def _is_valid_url(url):
        """Return True if a HEAD request to ``url`` answers with status 200."""
        try:
            response = requests.head(url, timeout=5)
        except requests.RequestException:
            return False
        return response.status_code == 200

    def __getitem__(self, idx):
        """Return ``(tensor, label)`` for row ``idx``.

        Raises:
            RuntimeError: If none of the row's four URLs is reachable.
        """
        row = self.x.iloc[idx]

        urls = [row[f'url{i}'] for i in range(1, 5)]
        valid_urls = [url for url in urls if self._is_valid_url(url)]
        if not valid_urls:
            # Previously ``None`` tiles were passed on to np.concatenate,
            # which failed with an unrelated-looking error.
            raise RuntimeError(f"None of the image URLs of row {idx} is reachable")

        # Repeat the reachable URLs until there are four tiles.
        while len(valid_urls) < 4:
            valid_urls.extend(valid_urls[:4 - len(valid_urls)])

        # Download every distinct URL only once, even if it fills several tiles.
        processed = {}
        for url in valid_urls[:4]:
            if url not in processed:
                processed[url] = gray_scale(img_cut(load_html_sample(url)))
        imgs = [processed[url] for url in valid_urls[:4]]

        # Tile the four images: side by side first, then stacked.
        top = np.concatenate((imgs[0], imgs[1]), axis=1)
        bottom = np.concatenate((imgs[2], imgs[3]), axis=1)
        final = np.concatenate((top, bottom), axis=0)

        # Convert to a float tensor and add a leading channel axis: [1, H, W].
        tensor = torch.from_numpy(final).float().unsqueeze(0)

        label = row['ml_label']
        # For classification the label usually has to become a LongTensor:
        # label = torch.tensor(label, dtype=torch.long)

        return tensor, label