In [1]:
import numpy as np
import ast
import random
import copy
from typing import List, Tuple, Union, Dict

class DataCorruption:
    """Apply one random corruption to a segment of a note sequence.

    The input is a flat list whose elements are either the separator token
    ``'<T>'`` or a note given as a 4-element list.  Judging from the mask
    placeholders used below ('P', 'V', 'O', 'D'), note fields are presumably
    ``[pitch, velocity, onset, duration]`` — confirm against the data source.

    All corruption functions return ``(corrupted_segment, corruption_type)``
    where ``corrupted_segment[0]`` is a label naming the corruption.  They do
    not mutate the list (or the note lists) they are given.
    """

    # Token that delimits segments in the flat input sequence.
    SEPARATOR = '<T>'
    # Inclusive pitch bounds that incorrect_transposition never leaves.
    MIN_PITCH = 0
    MAX_PITCH = 127

    def __init__(self):
        pass

    @staticmethod
    def read_data_from_file(file_path: str) -> List[Union[str, List[Union[str, int]]]]:
        """Read a file containing a Python literal and return the parsed value.

        ``ast.literal_eval`` only accepts literals, so, unlike ``eval``, it
        cannot execute code contained in the file.

        Args:
            file_path: Path to a UTF-8 text file holding a literal such as
                ``"[[60, 80, 0, 100], '<T>']"``.

        Returns:
            The parsed literal (expected to be the flat note/token list).

        Raises:
            OSError: If the file cannot be opened.
            ValueError, SyntaxError: If the content is not a valid literal.
        """
        with open(file_path, 'r', encoding='utf-8') as file:
            corruptions_string = file.read()
        return ast.literal_eval(corruptions_string)

    def seperateitems(self, data: List[Union[str, List[Union[str, int]]]]) -> List[Union[str, List[Union[str, int]]]]:
        """Group runs of non-separator elements into sub-lists.

        Separator tokens stay as standalone elements, so
        ``[n1, n2, '<T>', n3]`` becomes ``[[n1, n2], '<T>', [n3]]``.
        The grouped elements are the same objects as in *data* (not copied).
        """
        items = []
        current_item = []

        for element in data:
            if element == self.SEPARATOR:
                # Close the running group before emitting the separator.
                if current_item:
                    items.append(current_item)
                    current_item = []
                items.append(element)
            else:
                current_item.append(element)

        if current_item:
            items.append(current_item)

        return items

    def get_segment_to_corrupt(self, data: List[Union[str, List[Union[str, int]]]]) -> Tuple[List[int], int, List[Union[str, int]], bool]:
        """Pick a random segment (sub-list) of grouped *data* to corrupt.

        Args:
            data: Output of :meth:`seperateitems`.

        Returns:
            ``(segment_indices, chosen_index, chosen_segment, is_last)`` where
            *is_last* is True when the chosen segment is the final element.

        Raises:
            ValueError: If *data* contains no segment (no list element).
        """
        indices_to_corrupt = [n for n, item in enumerate(data) if isinstance(item, list)]
        if not indices_to_corrupt:
            raise ValueError('data contains no segment to corrupt')
        random_index = random.choice(indices_to_corrupt)
        last_idx_flag = random_index == len(data) - 1

        return indices_to_corrupt, random_index, data[random_index], last_idx_flag

    def pitch_velocity_mask(self, data: List[Union[str, List[Union[str, int]]]]) -> Tuple[List[Union[str, List[Union[str, int]]]], str]:
        """Replace fields 0 and 1 of every note with the placeholders 'P' and 'V'.

        Non-list elements are passed through unchanged.  *data* is not mutated.
        """
        masked = [
            ['P', 'V', note[2], note[3]] if isinstance(note, list) else note
            for note in data
        ]
        return ['pitch_velocity_mask'] + masked, 'pitch_velocity_mask'

    def onset_duration_mask(self, data: List[Union[str, List[Union[str, int]]]]) -> Tuple[List[Union[str, List[Union[str, int]]]], str]:
        """Replace fields 2 and 3 of every note with the placeholders 'O' and 'D'.

        Non-list elements are passed through unchanged.  *data* is not mutated.
        """
        masked = [
            [note[0], note[1], 'O', 'D'] if isinstance(note, list) else note
            for note in data
        ]
        return ['onset_duration_mask'] + masked, 'onset_duration_mask'

    def general_mask(self, data: List[Union[str, List[Union[str, int]]]]) -> Tuple[List[Union[str, List[Union[str, int]]]], str]:
        """Drop every note and return a single ``['mask', *strings]`` placeholder.

        Only the string elements of *data* are kept, in order, after 'mask'.
        """
        str_elements = [item for item in data if isinstance(item, str)]
        output = ['mask'] + str_elements

        return ['whole_mask'] + [output], 'whole_mask'

    def permute_pitches(self, data: List[Union[str, List[Union[str, int]]]]) -> Tuple[List[Union[str, List[Union[str, int]]]], str]:
        """Randomly shuffle field 0 (pitch) across the notes of the segment.

        All other fields stay with their original note.  *data* is not mutated.
        """
        pitches = [note[0] for note in data if isinstance(note, list)]
        random.shuffle(pitches)
        # An iterator hands out shuffled values in order without O(n) pop(0).
        shuffled = iter(pitches)

        permuted = [
            [next(shuffled)] + note[1:] if isinstance(note, list) else note
            for note in data
        ]
        return ['pitch_permutation'] + permuted, 'pitch_permutation'

    def permute_pitch_velocity(self, data: List[Union[str, List[Union[str, int]]]]) -> Tuple[List[Union[str, List[Union[str, int]]]], str]:
        """Independently shuffle field 0 (pitch) and field 1 (velocity) across notes.

        Fields 2 and 3 stay with their original note.  *data* is not mutated.
        """
        pitches = [note[0] for note in data if isinstance(note, list)]
        random.shuffle(pitches)

        velocities = [note[1] for note in data if isinstance(note, list)]
        random.shuffle(velocities)

        shuffled_pitches = iter(pitches)
        shuffled_velocities = iter(velocities)
        permuted = [
            [next(shuffled_pitches), next(shuffled_velocities)] + note[2:]
            if isinstance(note, list) else note
            for note in data
        ]
        return ['pitch_velocity_permutation'] + permuted, 'pitch_velocity_permutation'

    def fragmentation(self, data: List[Union[str, List[Union[str, int]]]]) -> Tuple[List[Union[str, List[Union[str, int]]]], str]:
        """Keep only the first 20–50% of the notes in the segment.

        The percentage is drawn uniformly from [0.2, 0.5] and applied to the
        number of notes (list elements), not to the raw element count, so
        interleaved string tokens do not shrink the kept fraction.  String
        elements are always kept.  A segment with very few notes may end up
        with no notes at all (e.g. 2 notes * 0.3 -> 0).
        """
        note_count = sum(1 for item in data if isinstance(item, list))
        fragment_percentage = random.uniform(0.2, 0.5)
        fragment_length = int(note_count * fragment_percentage)

        fragmented_data = []
        notes_kept = 0
        for item in data:
            if isinstance(item, list):
                if notes_kept < fragment_length:
                    fragmented_data.append(item)
                    notes_kept += 1
            else:
                fragmented_data.append(item)

        return ['fragmentation'] + fragmented_data, 'fragmentation'

    def incorrect_transposition(self, data: List[Union[str, List[Union[str, int]]]], max_shift: int = 5) -> Tuple[List[Union[str, List[Union[str, int]]]], str]:
        """Randomly shift field 0 (pitch) of about half the notes.

        Each note is selected with probability 1/2; a selected note is shifted
        by a uniform integer in ``[-max_shift, max_shift]``.  A shift that would
        leave ``[MIN_PITCH, MAX_PITCH]`` is skipped, so pitches stay in range.
        *data* is not mutated.

        Args:
            data: Segment to corrupt.
            max_shift: Largest absolute shift applied to a pitch (default 5).
        """
        transposed = []
        for item in data:
            if isinstance(item, list):
                item = list(item)  # copy so the caller's note is untouched
                if random.choice([True, False]):
                    new_pitch = item[0] + random.randint(-max_shift, max_shift)
                    if self.MIN_PITCH <= new_pitch <= self.MAX_PITCH:
                        item[0] = new_pitch
            transposed.append(item)

        return ['incorrect_transposition'] + transposed, 'incorrect_transposition'

    def apply_random_corruption(self, data: List[Union[str, List[Union[str, int]]]]) -> Dict[str, Union[List[Union[str, List[Union[str, int]]]], str, bool]]:
        """Corrupt one randomly chosen segment of *data* with a random corruption.

        *data* itself is never modified; all work happens on a deep copy.

        Returns:
            A dict with keys 'Original Data' (the *data* argument),
            'Corrupted Data' (flat sequence with the corrupted segment spliced
            in), 'Original Data Segment', 'Corrupted Data Segment',
            'Corruption Type' and 'Flag' (True when the corrupted segment is
            the last element after grouping).

        Raises:
            ValueError: If *data* has no notes to corrupt.
        """
        corruption_functions = [
            self.pitch_velocity_mask,
            self.onset_duration_mask,
            self.general_mask,
            self.permute_pitches,
            self.permute_pitch_velocity,
            self.fragmentation,
            self.incorrect_transposition,
        ]
        corruption_function = random.choice(corruption_functions)

        corrupted_data = self.seperateitems(copy.deepcopy(data))
        _, index, segment, last_idx_flag = self.get_segment_to_corrupt(corrupted_data)
        # Corrupt a separate copy so 'Original Data Segment' keeps the
        # pre-corruption notes even after the slot is overwritten below.
        segment_copy = copy.deepcopy(segment)

        corrupted_segment, corruption_type = corruption_function(segment_copy)
        corrupted_data[index] = corrupted_segment

        # Flatten the grouped sequence back into a single list.
        corrupted_data_sequence = []
        for element in corrupted_data:
            if isinstance(element, list):
                corrupted_data_sequence.extend(element)
            else:
                corrupted_data_sequence.append(element)

        return {
            'Original Data': data,
            'Corrupted Data': corrupted_data_sequence,
            'Original Data Segment': segment,
            'Corrupted Data Segment': corrupted_segment,
            'Corruption Type': corruption_type,
            'Flag': last_idx_flag,
        }


# Usage example
# Load the data from a file; 'corruptions.txt' must hold a Python literal
# (parsed with ast.literal_eval inside read_data_from_file)
data = DataCorruption.read_data_from_file('corruptions.txt')

# Create a DataCorruption instance (the constructor takes no arguments; the
# loaded data is passed to apply_random_corruption below)
data_corruption = DataCorruption()

# Apply a random corruption
output = data_corruption.apply_random_corruption(data)
# Bare trailing expression: in a notebook cell this displays the result dict
output


{'Original Data': [[66, 60, 1000, 100],
  [64, 60, 1000, 130],
  [73, 75, 1000, 130],
  [48, 60, 1290, 130],
  [73, 75, 1290, 130],
  [66, 60, 1320, 60],
  [64, 60, 1320, 100],
  [66, 60, 1640, 480],
  [57, 60, 1640, 510],
  [63, 45, 1640, 510],
  [53, 45, 1640, 540],
  [71, 60, 2090, 930],
  [66, 45, 2120, 800],
  [56, 45, 2120, 830],
  [62, 45, 2120, 830],
  [46, 60, 2150, 1150],
  [73, 75, 2920, 350],
  [66, 60, 2920, 510],
  [49, 45, 2950, 100],
  [55, 45, 2950, 350],
  [61, 60, 2950, 350],
  [51, 45, 2950, 450],
  [73, 60, 3270, 160],
  [61, 60, 3300, 100],
  [49, 45, 3300, 130],
  [55, 60, 3300, 130],
  [67, 45, 3330, 100],
  [61, 45, 3590, 100],
  [66, 60, 3590, 100],
  [73, 75, 3590, 130],
  [55, 45, 3620, 60],
  [54, 45, 3940, 670],
  [60, 45, 3940, 670],
  [66, 60, 3940, 670],
  [44, 45, 3940, 700],
  [64, 60, 3940, 700],
  [67, 45, 4550, 640],
  [71, 60, 4550, 640],
  [62, 45, 4580, 130],
  [49, 45, 4580, 610],
  [59, 45, 4580, 610],
  [65, 45, 4580, 610],
  '<T>',
  [57, 45