In [1]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim
import numpy as np
import torch

In [2]:
# Function to generate data
def generate_data(max_length, num_samples, p):
    data = []
    while len(data) < num_samples:
        n = np.random.randint(1, min((max_length + 2) // 3, 8))  # Limit n to ensure total length <= 20
        total_length = 3 * n
        if total_length > max_length:
            continue

        # Determine if the sample should be from the language or not based on probability p
        if np.random.rand() < p:
            sample = 'a' * n + 'b' * n + 'c' * n
            data.append((sample, 1))  # Label 1 for samples in the language
        else:
            sample = ''
            for _ in range(total_length):
                char = np.random.choice(['a', 'b', 'c'])
                sample += char
            data.append((sample, 0))  # Label 0 for non-language samples

    return data

#print(generate_data(20, 20,0.5))

In [3]:
class LanguageDataset(Dataset):
    def __init__(self, max_length, p, num_samples):
        self.max_length = max_length
        self.p = p
        self.num_samples = num_samples
        
        self.samples = self.generate_samples()
    
    def generate_samples(self):
        samples = []
        for _ in range(self.num_samples):
            length = np.random.randint(1, self.max_length + 1)
            sequence = self.generate_sequence(length)
            samples.append((sequence, int(self.is_language(sequence))))
        return samples
    
    def generate_sequence(self, length):
        sequence = ""
        for _ in range(length):
            choice = np.random.choice(['a', 'b', 'c'], p=self.p)
            sequence += choice
        return sequence
    
    def is_language(self, sequence):
        # Check if the sequence belongs to the language a^n b^n c^n
        counts = {'a': 0, 'b': 0, 'c': 0}
        for char in sequence:
            counts[char] += 1
        return counts['a'] == counts['b'] == counts['c']
    
    def __len__(self):
        return len(self.samples)
    
    def __getitem__(self, idx):
        return self.samples[idx]

def generate_data(max_length, p, num_samples, batch_size=32):
    dataset = LanguageDataset(max_length, p, num_samples)
    data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    return data_loader

# Example usage:
max_length = 20
p = [0.3, 0.3, 0.4]  # Probability distribution for choosing characters
num_samples = 1000
batch_size = 32

data_loader = generate_data(max_length, p, num_samples, batch_size)

# Printing first batch as an example
for batch in data_loader:
    print(batch)
    break


[('c', 'baaab', 'abaaacccccbbbc', 'cacccccbbcaccaaabb', 'bbbacbbbc', 'cacacacbaaa', 'babc', 'cc', 'cab', 'c', 'aaacaccabacaacabaa', 'bc', 'abaaaac', 'ccbccccbacaccaabacba', 'caabcbccccbac', 'abaaccabbacbccccb', 'acabacbc', 'ac', 'bccabcbababcac', 'cbcbcacccbca', 'c', 'bbcacbbb', 'cbbabacbbbcbcba', 'abcccbaaac', 'cbbacccbca', 'acaabbcbabbcacabcabc', 'cbbababaabaacbbabca', 'bcbaccaca', 'cacbbccaaababcbcbc', 'bcbbaa', 'cccaa', 'aacbabccccbcaab'), tensor([0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0])]
