In [None]:
import os
import re
import time
from pathlib import Path
import numpy as np
import pandas as pd
import holoviews as hv
import hvplot.pandas
from holoviews import opts,dim
import matplotlib.pyplot as plt
import cv2
import random
from PIL import Image
from PIL import ImageFilter
from PIL import ImageEnhance
from IPython.display import display_png
opts.defaults(opts.Curve(width=600, framewise=True))
import logging
logging.getLogger("requests").setLevel(logging.WARNING)
logging.getLogger("urllib3").setLevel(logging.WARNING)
%matplotlib inline
hv.extension('bokeh')
# (height, width) in pixels of every exported window plot: unpacked by load_signal
# and used as the default `dims` of divide_signal_c0 / divide_signal_c1.
dims = [150,330]

### Data

In [None]:
# Per-record label table; later cells use its rec_id, class and pH columns.
labels = pd.read_csv(filepath_or_buffer='../data/label/labels_with_class.csv')

In [None]:
data_path = Path('../data')
# One entry per '<id>.txt' signal file: the record path without its extension,
# e.g. '../data/1001'. endswith() (not a substring test) keeps names such as
# 'x.txt.bak' out, which the [:-4] slice would otherwise mangle.
file_names = [str(data_path/file)[:-4] for file in sorted(os.listdir(data_path)) if file.endswith('.txt')]
# Bare record ids as strings ('1001'); Path.name is separator-agnostic.
rec_ids = [Path(file).name for file in file_names]

In [None]:
# Overwrite the step1..step4 label columns with the pH column.
# NOTE(review): this makes the step values returned by Record.get_labels equal
# to pH; confirm that is intended.
# Item assignment is used deliberately: attribute assignment (labels.step1 = ...)
# silently sets a plain Python attribute instead of a column when the column is
# missing, so the overwrite would quietly not happen.
for step_col in ('step1', 'step2', 'step3', 'step4'):
    labels[step_col] = labels['pH']

### Basic signal helpers (moving average, zero filling)

In [None]:
def ma(x, period, type_ ='simple'):
    """Causal moving average of ``x`` over ``period`` samples.

    Args:
        x: 1-D sequence of numbers.
        period: window length in samples (>= 1).
        type_: ``'simple'`` for equal weights; any other value selects
            exponentially decaying weights that favour the most recent sample.

    Returns:
        Float array of ``len(x)``. Entry ``n`` (for ``n >= period - 1``) is the
        weighted average of ``x[n - period + 1 .. n]``. The first ``period - 1``
        entries, whose windows are incomplete, are set to the first
        full-window value.

    Raises:
        ValueError: if ``period < 1`` or ``len(x) < period``.
    """
    x = np.asarray(x, dtype=float)
    if period < 1:
        raise ValueError(f'period must be >= 1, got {period}')
    if len(x) < period:
        raise ValueError(f'need at least period={period} samples, got {len(x)}')

    # Compare against the parameter, not the builtin `type`.
    if type_ == 'simple':
        weights = np.ones(period)
    else:
        # np.convolve pairs weights[0] with the newest sample, so the weights
        # must decay from 1 (newest) to e^-1 (oldest).
        weights = np.exp(np.linspace(0., -1., period))
    weights /= weights.sum()

    a = np.convolve(x, weights, mode='full')[:len(x)]
    # Index period-1 is the first complete window; back-fill the warm-up before it.
    a[:period - 1] = a[period - 1]
    return a

def previous_points(data,from_point,num_of_points):
    """Return up to ``num_of_points`` items of ``data`` immediately before index ``from_point``.

    The window is clipped at the start of ``data``. When fewer than
    ``num_of_points`` items precede ``from_point``, only those are returned;
    an unclipped negative start index would wrap around to the end of
    ``data`` and yield a wrong (often empty) slice.
    """
    start = max(0, from_point - num_of_points)
    return data[start:from_point]

def fill_mean(fhr,window_size):
    """Replace zero samples with the mean of the preceding ``window_size`` samples.

    Scanning starts at index ``window_size + 1``, so zeros before that are left
    as-is. Zero positions are taken from the input. Each mean is computed from
    the already-corrected output, so a run of zeros is filled from earlier fills.
    The input is never modified; a corrected copy is returned.
    """
    filled = fhr.copy()
    first_idx = window_size + 1
    for idx in range(first_idx, len(fhr)):
        if fhr[idx] != 0:
            continue
        filled[idx] = np.mean(filled[idx - window_size:idx])
    return filled

In [None]:
class Record:
    """One recording loaded from ``<location>.txt`` (signals) and ``<location>.hea`` (header).

    ``location`` is the record path without extension, e.g. ``'../data/1001'``.
    Its last path component must be the integer record id. Label values come
    from the module-level ``labels`` DataFrame.

    Attributes:
        rec_id: integer id parsed from ``location``.
        fhr, uc: FHR and UC samples as read from disk (raw integers / 100);
            zero FHR samples are left in place.
        length: number of samples.
        info: header fields from ``'#key value'`` lines (8th line onward), as strings.
        freq: third space-separated token of the header's first line, as int.
        labels: label values for this record (see ``get_labels``).
        pos2stage: ``int(info['Pos'])``, the sample index where stage 2 begins.
        vline: dotted black vertical line at ``pos2stage`` for the plots.
    """

    def __init__(self, location):
        # Signals are kept as read; a zero-removed FHR is produced on demand by
        # preprocess_fhr().
        self.rec_id = int(Path(str(location)).name)
        self.fhr, self.uc = self.read_signals(location)
        self.length = len(self.fhr)
        self.info, self.freq = self.read_info(location)
        self.labels = self.get_labels()
        self.pos2stage = int(self.info['Pos'])  # beginning of stage 2
        self.vline = hv.VLine(self.pos2stage).opts(color='black', line_dash='dotted', line_width=1)

    def __str__(self):
        return str(self.rec_id)

    def __repr__(self):
        return self.__str__()

    def read_signals(self, location):
        """Read ``<location>.txt`` and return ``(fhr, uc)`` as float arrays (raw ints / 100).

        Each non-blank line is tab-separated; column 1 is FHR and column 2 is UC.
        Blank lines (e.g. a trailing empty line) are skipped.
        """
        fhr_l = []
        uc_l = []
        with open(str(location) + '.txt') as f:  # closed even if parsing fails
            for line in f:
                if not line.strip():
                    continue
                fields = line.split('\t')
                fhr_l.append(int(fields[1]))
                uc_l.append(int(fields[2]))
        return np.asarray(fhr_l) / 100, np.asarray(uc_l) / 100

    def read_info(self, location):
        """Parse ``<location>.hea`` into ``(info, freq)``.

        From the 8th line onward, every line containing ``'#'`` and at least one
        digit contributes ``info[first word] = first number`` (both as strings).
        ``freq`` is the third space-separated token of the first line, as int.
        """
        with open(str(location) + '.hea') as f:
            lines = f.readlines()
        info = dict()
        for line in lines[7:]:
            if '#' in line and re.search(r'\d+', line):
                key = re.search(r'\w+', line).group()
                info[key] = re.search(r"[-+]?\d*\.\d+|\d+", line).group()
        freq = int(lines[0].split(' ')[2])  # Frequency
        return info, freq

    def get_info_df(self):
        """Return ``info`` as a one-column DataFrame indexed by field name."""
        df = pd.DataFrame.from_dict(self.info, orient='index')
        return df

    def get_labels(self):
        """Return this record's values at columns 1, 2, 3 and 6 of ``labels``.

        Per the original layout note these are step1, step2, step3 and pH.

        Raises:
            IndexError: if ``labels`` has no row for ``rec_id``.
        """
        rec_labels = labels[labels.rec_id == self.rec_id]
        if rec_labels.empty:
            raise IndexError(f'no row in labels for rec_id {self.rec_id}')
        return rec_labels.values[0][[1, 2, 3, 6]]  # step1, step2, step3, pH

    def preprocess_fhr(self):
        """Drop zero FHR samples and re-index ``pos2stage`` into the shorter signal.

        Returns:
            ``(fhr_without_zeros, new_pos2stage)``. ``new_pos2stage`` is
            ``pos2stage`` minus the number of zeros *before* it. Zeros after the
            stage-2 marker do not move it.
        """
        signal = np.asarray(self.fhr)
        keep = signal != 0
        missing_before = int(np.count_nonzero(~keep[:self.pos2stage]))
        return signal[keep], self.pos2stage - missing_before

    def plot_fhr(self, orignal=True):
        """Plot FHR with the stage-2 marker and dotted red guides at 110 and 160.

        Args:
            orignal: True for the signal as read. False for the zero-removed
                signal from ``preprocess_fhr()``, marked with a green stage-2 line.
        """
        if orignal:
            fhr, vline = self.fhr, self.vline
        else:
            # Computed on demand: __init__ does not precompute a processed signal.
            fhr, pos2stage_pro = self.preprocess_fhr()
            vline = hv.VLine(pos2stage_pro).opts(color='green', line_dash='dotted')
        hline_1 = hv.HLine(160).opts(color='red', line_dash='dotted', line_width=1)
        hline_2 = hv.HLine(110).opts(color='red', line_dash='dotted', line_width=1)
        fhr_plot = hv.Curve(fhr, 'Time', 'FHR') * vline * hline_1 * hline_2
        fhr_plot.opts(
            opts.Curve(height=400, width=700, xaxis=None, line_width=1, tools=['hover'], line_alpha=1))
        return fhr_plot

    def plot_uc(self):
        """Plot UC (orange) with the stage-2 marker."""
        uc_plot = hv.Curve(self.uc, 'Time', 'UC') * self.vline
        uc_plot.opts(
            opts.Curve(height=250, width=700, xaxis=None, line_width=1, tools=['hover'], color='orange'))
        return uc_plot

    def plot_labels(self):
        """Bar chart of ``self.labels`` with dotted blue guides at 1, 2 and 3."""
        label_plot = hv.Bars(self.labels)
        label_plot.opts(opts.Bars(height=250, width=250, color='green'))
        l_1 = hv.HLine(1).opts(color='blue', line_dash='dotted')
        l_2 = hv.HLine(2).opts(color='blue', line_dash='dotted')
        l_3 = hv.HLine(3).opts(color='blue', line_dash='dotted')
        return label_plot * l_1 * l_2 * l_3

### Load and plot a record (interactive viewer)

In [None]:
def load_signals(location, **kwargs):
    """Build an overlay of one whole record for interactive browsing.

    Layers:
      * FHR, zero-filled with ``fill_mean`` (window 50), in blue;
      * its 100-sample simple moving average, in black;
      * UC shifted down by 70 so it sits below FHR on the shared y axis, in pink,
        together with the record's stage-2 marker;
      * red dotted guides at FHR 110/160 and UC 20/60/100 (shifted like UC).

    Args:
        location: record path without extension, passed to ``Record``.
        **kwargs: accepted and ignored.

    Returns:
        HoloViews overlay sized 800x400.
    """
    uc_shift_down = 70
    rec = Record(location)

    corrected_fhr = fill_mean(rec.fhr, window_size=50)
    mean_fill = hv.Curve(corrected_fhr).opts(line_width=1, line_alpha=0.9, color='#1b92f9')

    value_ma_fhr = ma(corrected_fhr, period=100, type_='simple')
    ma_fhr_plot = hv.Curve(value_ma_fhr, 'FHR_ema').opts(line_width=1, color='black')

    fhr_upper = hv.HLine(160).opts(line_dash='dotted', line_width=1, color='red')
    fhr_lower = hv.HLine(110).opts(line_dash='dotted', line_width=1, color='red')

    uc_plot = hv.Curve(rec.uc - uc_shift_down, 'Time', 'UC') * rec.vline
    uc_plot.opts(opts.Curve(xaxis=None, line_width=1, tools=['hover'], color='pink'))
    # The UC moving average is intentionally not computed: it was never added to
    # the layout, and it cost an extra pass (and could fail) on every update.

    uc_l1 = hv.HLine(20 - uc_shift_down).opts(line_dash='dotted', line_width=1, color='red')
    uc_l2 = hv.HLine(60 - uc_shift_down).opts(line_dash='dotted', line_width=1, color='red')
    uc_l3 = hv.HLine(100 - uc_shift_down).opts(line_dash='dotted', line_width=1, color='red')

    layout = mean_fill * ma_fhr_plot * uc_plot * fhr_upper * fhr_lower * uc_l1 * uc_l2 * uc_l3
    layout.opts(height=400, width=800)
    return layout

# Record browser: the 'FHR' key dimension takes values from file_names, and each
# selection is rendered by load_signals. Display `dmap` to use it.
dmap = hv.DynamicMap(load_signals, kdims='FHR')
dmap = dmap.redim.values(FHR=file_names)
# dmap

### Helper functions

In [None]:
# Bokeh renderer handle. NOTE(review): not referenced by the cells shown in this
# notebook (images are written with hv.save(..., backend='bokeh')); confirm before removing.
renderer = hv.renderer("bokeh")
def load_signal(fhr,uc,dims):
    """Render one FHR/UC window as a single overlay for image export.

    FHR is drawn in black. UC is drawn in orange, shifted down by 20 so both
    share one y axis. Red guides:
      * FHR 110, 40 and 180 (solid);
      * FHR 150 and 130 (dotted);
      * UC 10 and 30 (dotted, shifted like UC).

    Args:
        fhr: FHR samples of the window (callers here pass fill_mean output).
        uc: UC samples of the same window.
        dims: ``(height, width)`` of the plot in pixels.

    Returns:
        HoloViews overlay. x is padded by 20 samples on each side; y spans
        from -30 to ``max(fhr) + 10``.
    """
    plot_height, plot_width = dims
    uc_offset = 20

    def guide(level, dash):
        return hv.HLine(level).opts(line_dash=dash, line_width=1, color='red')

    fhr_trace = hv.Curve(fhr).opts(line_width=1.2, line_alpha=0.9, color='black')
    uc_trace = hv.Curve(uc - uc_offset, 'Time', 'UC')
    # colours chosen so the two traces stay visually separate
    uc_trace.opts(opts.Curve(xaxis=None, line_width=2, tools=['hover'], color='orange'))

    # Layer order matters for drawing; keep it exactly as before.
    layers = [
        fhr_trace,
        uc_trace,
        guide(150, 'dotted'),
        guide(110, 'solid'),
        guide(10 - uc_offset, 'dotted'),
        guide(30 - uc_offset, 'dotted'),
        guide(130, 'dotted'),
        guide(40, 'solid'),
        guide(180, 'solid'),
    ]
    overlay = layers[0]
    for layer in layers[1:]:
        overlay = overlay * layer

    x_extent = max(len(fhr), len(uc))
    overlay.opts(height=plot_height, width=plot_width,
                 xlim=(-20, x_extent + 20), ylim=(-30, np.max(fhr) + 10))
    return overlay

In [None]:
def divide_signal_c0(location,dims=dims):
    """Cut one class-0 record into 800-sample FHR/UC window plots.

    FHR is zero-filled with ``fill_mean`` (window 50); UC is used as read.
    With ``last = len(fhr) - 1``, windows ``[start, end)`` are taken in this order:

    1. ``[k*800, (k+1)*800)`` for ``k < last // 800``;
    2. when ``last >= 800`` and more than 150 samples remain after pass 1, a
       tail window ``[last - 800, last)``, so the end of the recording is kept;
    3. ``[k*800 + 400, (k+1)*800 + 400)`` for ``k < (last - 400) // 800``.

    Windows whose mean FHR is below 40 are skipped.

    Args:
        location: record path without extension, passed to ``Record``.
        dims: ``(height, width)`` forwarded to ``load_signal``.

    Returns:
        List of overlays, one per kept window, in the order above.
    """
    window = 800
    min_tail = 150
    min_mean_fhr = 40
    second_offset = 400

    rec = Record(location)
    fhr = fill_mean(rec.fhr, window_size=50)
    uc = rec.uc
    last = len(fhr) - 1

    bounds = [(k * window, (k + 1) * window) for k in range(last // window)]
    # The tail end must go with the tail start; appending it to the start list
    # made zip() drop the tail window entirely.
    if last >= window and last % window > min_tail:
        bounds.append((last - window, last))
    bounds.extend((k * window + second_offset, (k + 1) * window + second_offset)
                  for k in range((last - second_offset) // window))

    plots = []
    for start, end in bounds:
        fhr_g = fhr[start:end]
        if np.mean(fhr_g) >= min_mean_fhr:
            plots.append(load_signal(fhr_g, uc[start:end], dims=dims))
    return plots

In [None]:
def divide_signal_c1(location,dims=dims):
    """Cut one class-1 record into heavily overlapping 800-sample FHR/UC window plots.

    FHR is zero-filled with ``fill_mean`` (window 50); UC is used as read.
    With ``last = len(fhr) - 1``, windows ``[start, end)`` are taken in this order:

    1. ``[k*800, (k+1)*800)`` for ``k < last // 800``;
    2. when ``last >= 800`` and more than 150 samples remain after pass 1, a
       tail window ``[last - 800, last)``;
    3. for each offset 100, 200, ..., 700, the windows
       ``[k*800 + offset, (k+1)*800 + offset)`` for ``k < (last - offset) // 800``.

    Each offset is used once. The earlier version ran the offset-700 pass twice
    and so emitted every such window as a duplicate. Windows whose mean FHR is
    below 20 are skipped.

    Args:
        location: record path without extension, passed to ``Record``.
        dims: ``(height, width)`` forwarded to ``load_signal``.

    Returns:
        List of overlays, one per kept window, in the order above.
    """
    window = 800
    min_tail = 150
    min_mean_fhr = 20
    offset_step = 100

    rec = Record(location)
    fhr = fill_mean(rec.fhr, window_size=50)
    uc = rec.uc
    last = len(fhr) - 1

    bounds = [(k * window, (k + 1) * window) for k in range(last // window)]
    # The tail end must go with the tail start; appending it to the start list
    # made zip() drop the tail window entirely.
    if last >= window and last % window > min_tail:
        bounds.append((last - window, last))
    for offset in range(offset_step, window, offset_step):  # 100 .. 700
        bounds.extend((k * window + offset, (k + 1) * window + offset)
                      for k in range((last - offset) // window))

    plots = []
    for start, end in bounds:
        fhr_g = fhr[start:end]
        if np.mean(fhr_g) >= min_mean_fhr:
            plots.append(load_signal(fhr_g, uc[start:end], dims=dims))
    return plots

In [None]:
# plots = divide_signal_c1(file_names[0],)
# plots[-1]

In [None]:
def rotate_(image, angle, color = "#ffffff"):
    """Rotate ``image`` by ``angle`` degrees onto a solid ``color`` "L" canvas.

    Bicubic resampling with ``expand=True``, so the canvas grows to fit. The
    rotated image is pasted using itself as the mask. Zero-valued pixels,
    including the empty corners the rotation creates, therefore show ``color``
    (white by default) instead of black.
    NOTE(review): using the image as its own mask presumably expects an "L" image.
    """
    rotated = image.rotate(angle, resample=Image.BICUBIC, expand=True)
    canvas = Image.new("L", rotated.size, color)
    canvas.paste(rotated, mask=rotated)
    return canvas

# def rotate_with_fill(image, angle,contrast =1.5, color = "#ffffff"): #main rotation fn
#     i = rotate_(image, angle, color)
#     i = i.resize(image.size)
#     enh = ImageEnhance.Contrast(i)
#     return enh.enhance(contrast)

def save_augmentations(img,save_path,rec_id,label,div,rng=None):
    """Crop one rendered record plot and save it plus three enhanced, rotated variants.

    Writes four PNGs named ``{save_path}/{rec_id}_label_{label}_div_{div}_{k}.png``:

    * k=1: the cropped image as is;
    * k=2: brightness x1.1, rotated by a random integer angle in [-6, 6);
    * k=3: sharpness x6, rotated by a random integer angle in [-7, 7);
    * k=4: sharpness x3, rotated by a random integer angle in [-6, 6).

    Rotation keeps the image size and fills the uncovered corners with white,
    as the ``rotate_`` helper does, instead of the default black.

    Args:
        img: PIL image of one record plot (callers here pass an "L"-mode image).
        save_path: output directory.
        rec_id, label, div: values interpolated into the file names.
        rng: optional ``numpy.random.Generator`` for reproducible angles. When
            None, the global ``np.random`` state is used as before.
    """
    img = img.crop((52,0,302,150))  # crop out unwanted portion: keep x in [52, 302), rows [0, 150)
    stem = str(Path(f'{save_path}/{rec_id}_label_{label}_div_{div}'))
    random_angle = np.random.randint if rng is None else rng.integers  # upper bound exclusive

    img.save(f'{stem}_1.png')
    variants = (
        (ImageEnhance.Brightness, 1.1, -6, 6),
        (ImageEnhance.Sharpness, 6, -7, 7),
        (ImageEnhance.Sharpness, 3, -6, 6),
    )
    for k, (enhancer, factor, low, high) in enumerate(variants, start=2):
        angle = random_angle(low, high)
        enhancer(img).enhance(factor).rotate(angle, fillcolor='white').save(f'{stem}_{k}.png')

### Save Images

In [None]:
# Output tree: per-record plots under records/, augmented train/test images under
# data/, plus a models/ folder. Parents are listed explicitly, in creation order.
for directory in (
    "../data/image_data",
    "../data/image_data/data",
    "../data/image_data/records",
    "../data/image_data/records/class_0",
    "../data/image_data/records/class_1",
    "../data/image_data/models",
    "../data/image_data/data/train",
    "../data/image_data/data/train/class_0",
    "../data/image_data/data/train/class_1",
    "../data/image_data/data/test",
    "../data/image_data/data/test/class_0",
    "../data/image_data/data/test/class_1",
):
    os.makedirs(directory, exist_ok=True)


In [None]:
# Split the record ids found on disk by the 'class' column of labels, keeping the
# order of file_names. The per-class id sets are built once; previously the label
# filter was re-evaluated for every element of every comprehension.
_disk_ids = [int(name.split('/')[-1]) for name in file_names]
_class_0_label_ids = set(labels.loc[labels['class'] == 0, 'rec_id'].tolist())
_class_1_label_ids = set(labels.loc[labels['class'] == 1, 'rec_id'].tolist())

class_0_ids = [str(f) for f in _disk_ids if f in _class_0_label_ids]
class_0_files = [f'../data/{f}' for f in _disk_ids if f in _class_0_label_ids]

class_1_ids = [str(f) for f in _disk_ids if f in _class_1_label_ids]
class_1_files = [f'../data/{f}' for f in _disk_ids if f in _class_1_label_ids]

len(class_0_files),len(class_1_files)

### Class_0 record generation

In [None]:
# Render every class-0 record into window plots under records/class_0/<id>/.
# Failures are collected in error_list as (file, plot index or None, exception)
# so one bad record does not stop the run. Catching Exception (not a bare except)
# keeps KeyboardInterrupt / "interrupt kernel" working.
error_list = []
count = 0
label = 0
for file in class_0_files:
    file_name = str(file).split('/')[-1]
    record_dir = f"../data/image_data/records/class_0/{file_name}"
    os.makedirs(record_dir, exist_ok=True)

    i = None  # index of the plot being saved; stays None if divide_signal_c0 fails
    try:
        graphs = divide_signal_c0(file)
        for i, graph in enumerate(graphs):
            graph_name = f'{record_dir}/{file_name}_label_{label}_div_{i}.png'
            hv.save(graph, graph_name, backend='bokeh')
    except Exception as err:
        error_list.append((file, i, err))

    count += 1
    print(count)

error_list

### Class_1 record generation

In [None]:
# Render every class-1 record into window plots under records/class_1/<id>/.
# Failures are collected in error_list as (file, plot index or None, exception)
# so one bad record does not stop the run. Catching Exception (not a bare except)
# keeps KeyboardInterrupt / "interrupt kernel" working.
error_list = []
count = 0
label = 1
for file in class_1_files:
    file_name = str(file).split('/')[-1]
    record_dir = f"../data/image_data/records/class_1/{file_name}"
    os.makedirs(record_dir, exist_ok=True)

    i = None  # index of the plot being saved; stays None if divide_signal_c1 fails
    try:
        graphs = divide_signal_c1(file)
        for i, graph in enumerate(graphs):
            graph_name = f'{record_dir}/{file_name}_label_{label}_div_{i}.png'
            hv.save(graph, graph_name, backend='bokeh')
    except Exception as err:
        error_list.append((file, i, err))

    count += 1
    print(count)

error_list

### Train/test split and augmentation by label

In [None]:
# 20% of 508 and of 44, presumably the class-0 / class-1 record counts.
# NOTE(review): the split cell hard-codes 100 and 8, not these values (101, 8).
(int(508 * 0.2), int(44 * 0.2))

In [None]:
# Hold out the last N ids of each class (in file order) for test; each side is
# then sorted. NOTE(review): the split is positional, not shuffled.
N_TEST_CLASS_0 = 100
N_TEST_CLASS_1 = 8

class_0_train_ids, class_0_test_ids = sorted(class_0_ids[:-N_TEST_CLASS_0]), sorted(class_0_ids[-N_TEST_CLASS_0:])

class_1_train_ids, class_1_test_ids = sorted(class_1_ids[:-N_TEST_CLASS_1]), sorted(class_1_ids[-N_TEST_CLASS_1:])

In [None]:
# Sizes of the four splits: class-0 train/test, class-1 train/test.
(len(class_0_train_ids), len(class_0_test_ids), len(class_1_train_ids), len(class_1_test_ids))

#### Class0

In [None]:
# Per-record PNG folders written by the class-0 generation loop.
record_path = Path('../data/image_data/records/class_0')

##### Class_0 Train Augment

In [None]:
# Augment every class-0 training record's window plots into train/class_0.
image_data = Path('../data/image_data/data/train/')

label = 0
save_path = Path(image_data/f'class_{label}')
for rec_id in class_0_train_ids:
    # sorted(): os.listdir order is arbitrary; keep the processing order deterministic
    for im_file in sorted(os.listdir(record_path/rec_id)):
        if not im_file.endswith('.png'):
            continue  # ignore stray entries (e.g. checkpoint folders)
        div = im_file.split('_')[-1].split('.')[0]
        with Image.open(record_path/rec_id/im_file) as src:  # release the file handle
            img = src.convert("L")
        save_augmentations(img, save_path, rec_id, label, div)

##### Class_0 Test Augment

In [None]:
# Augment every class-0 test record's window plots into test/class_0.
image_data = Path('../data/image_data/data/test/')

label = 0
save_path = Path(image_data/f'class_{label}')
for rec_id in class_0_test_ids:
    # sorted(): os.listdir order is arbitrary; keep the processing order deterministic
    for im_file in sorted(os.listdir(record_path/rec_id)):
        if not im_file.endswith('.png'):
            continue  # ignore stray entries (e.g. checkpoint folders)
        div = im_file.split('_')[-1].split('.')[0]
        with Image.open(record_path/rec_id/im_file) as src:  # release the file handle
            img = src.convert("L")
        save_augmentations(img, save_path, rec_id, label, div)

#### Class1

In [None]:
# Per-record PNG folders written by the class-1 generation loop.
record_path = Path('../data/image_data/records/class_1')

##### Class_1 Train Augment

In [None]:
# Augment every class-1 training record's window plots into train/class_1.
image_data = Path('../data/image_data/data/train/')

label = 1
save_path = Path(image_data/f'class_{label}')
for rec_id in class_1_train_ids:
    # sorted(): os.listdir order is arbitrary; keep the processing order deterministic
    for im_file in sorted(os.listdir(record_path/rec_id)):
        if not im_file.endswith('.png'):
            continue  # ignore stray entries (e.g. checkpoint folders)
        div = im_file.split('_')[-1].split('.')[0]
        with Image.open(record_path/rec_id/im_file) as src:  # release the file handle
            img = src.convert("L")
        save_augmentations(img, save_path, rec_id, label, div)

##### Class_1 Test Augment

In [None]:
# Augment every class-1 test record's window plots into test/class_1.
image_data = Path('../data/image_data/data/test/')

label = 1
save_path = Path(image_data/f'class_{label}')
for rec_id in class_1_test_ids:
    # sorted(): os.listdir order is arbitrary; keep the processing order deterministic
    for im_file in sorted(os.listdir(record_path/rec_id)):
        if not im_file.endswith('.png'):
            continue  # ignore stray entries (e.g. checkpoint folders)
        div = im_file.split('_')[-1].split('.')[0]
        with Image.open(record_path/rec_id/im_file) as src:  # release the file handle
            img = src.convert("L")
        save_augmentations(img, save_path, rec_id, label, div)

In [None]:
# Oversample train/class_1: save two extra copies ("<stem>a.png", "<stem>b.png")
# of every augmented image. Only originals are copied (their stem ends in the
# augmentation digit written by save_augmentations), so re-running this cell
# rewrites the same copies instead of copying the copies.
image_data = Path('../data/image_data/data/train/class_1/')
originals = [f for f in sorted(os.listdir(image_data))
             if f.endswith('.png') and f.split('.')[0][-1:].isdigit()]
for f in originals:
    stem = f.split('.')[0]
    with Image.open(image_data/f) as img:
        img.save(image_data/(stem + 'a' + '.png'))
        img.save(image_data/(stem + 'b' + '.png'))

In [None]:
# Oversample test/class_1: save two extra copies ("<stem>a.png", "<stem>b.png")
# of every augmented image. Only originals are copied (their stem ends in the
# augmentation digit written by save_augmentations), so re-running is idempotent.
# NOTE(review): duplicating test images only re-weights test metrics; confirm intended.
image_data = Path('../data/image_data/data/test/class_1/')
originals = [f for f in sorted(os.listdir(image_data))
             if f.endswith('.png') and f.split('.')[0][-1:].isdigit()]
for f in originals:
    stem = f.split('.')[0]
    with Image.open(image_data/f) as img:
        img.save(image_data/(stem + 'a' + '.png'))
        img.save(image_data/(stem + 'b' + '.png'))

### Delete files from class_1

In [None]:
image_data = Path('../data/image_data/data/train/class_1/')
image_files = os.listdir(image_data)
# Despite its name, a_files collects the "...b.png" duplicates (last character
# before the extension is 'b') created by the oversampling cell.
a_files = list(filter(lambda name: name.split('.')[-2][-1] == 'b', image_files))

In [None]:
# Sort before the seeded shuffle. os.listdir order is filesystem-dependent, and
# re-running would otherwise reshuffle an already shuffled list. Sorting makes the
# chosen order (and thus which files are deleted) the same on every run.
a_files.sort()
random.Random(4).shuffle(a_files)  # in place; returns None
len(a_files)

In [None]:
# First of the shuffled "b" duplicates; the next cell deletes a_files[:200].
a_files[0] 

In [None]:
# Delete the first 200 shuffled "b" duplicates from train/class_1.
for doomed in a_files[:200]:
    (image_data / doomed).unlink()

### Delete files from class_0

In [None]:
image_data = Path('../data/image_data/data/train/class_0/')
# sorted() before the seeded shuffle so the same files are selected on every run
# (os.listdir order is filesystem-dependent).
image_files = sorted(os.listdir(image_data))
random.Random(4).shuffle(image_files)  # in place; returns None
len(image_files)

In [None]:
# Delete the first 384 shuffled images from train/class_0.
for doomed in image_files[:384]:
    (image_data / doomed).unlink()

In [None]:
sum(1 for _ in Path('../data/image_data/data/train/class_0/').iterdir())  # train/class_0 file count

In [None]:
sum(1 for _ in Path('../data/image_data/data/train/class_1/').iterdir())  # train/class_1 file count

In [None]:
sum(1 for _ in Path('../data/image_data/data/test/class_0/').iterdir())  # test/class_0 file count

In [None]:
sum(1 for _ in Path('../data/image_data/data/test/class_1/').iterdir())  # test/class_1 file count