In [1]:
import sys
sys.path.append('../dataset')
import criteo_search2
from net import Net_embedding
import torch
from tqdm.auto import tqdm
from utils import AverageMeter
from Randomized import RandomizedLabelPrivacy

import numpy as np
import math

In [3]:
dataset = criteo_search2.CriteoSearchDataset("../data/" + 'Criteo_Search.txt')
seed = 2024

In [4]:
train_length = int(len(dataset) * 0.8)
test_length = len(dataset) - train_length
train_dataset, test_dataset = torch.utils.data.random_split(
            dataset, (train_length, test_length), generator=torch.Generator().manual_seed(seed))

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print('train_dataset length: ', len(train_dataset), 'test_dataset length: ', len(test_dataset), 'device: ', device)

train_dataset length:  1316781 test_dataset length:  329196 device:  cuda


In [None]:
def compute_optimal_interval2(interval_freq, node, epsilon, delta):
    
    # Step 3: RPWithPrior i.e. Algorithm 1 in paper
    k = len(interval_freq)
    fmax = 0 # max value of f
    for i in range(k):
        for j in range(i+1, k):
            h = interval_freq[i] * node[i+1] + \
                torch.sum(interval_freq[i+1:j] * (node[i+2:j+1] - node[i+1:j])) - \
                    interval_freq[j] * node[j]
            c1 = 2 * delta * interval_freq[i] - math.exp(-epsilon) *  h
            slope = math.exp(-epsilon) * (interval_freq[j] - interval_freq[i])
            
            d11 = slope * node[j] -c1
            d12 = slope * node[j+1] - c1
            
            c2 = 2 * delta * interval_freq[j] - math.exp(-epsilon) * h
            
            d21 = -slope * node[i] + c2
            d22 = -slope * node[i+1] + c2
            e1 = c1 / slope
            e2 = c2 / slope
            
            A1max = node[i]
            A2max = node[j]
            h1 = (h + interval_freq[j] * A2max - interval_freq[i] * A1max) / (
                2 * delta+math.exp(-epsilon)*(A2max - A1max) )
            if fmax < h1:
                fmax = h1
                A1 = A1max
                A2 = A2max
            
            # (n_i,n_{j+1})
            A2max = node[j+1]
            h1 = (h + interval_freq[j] * A2max - interval_freq[i] * A1max) / (
                2 * delta+math.exp(-epsilon)*(A2max - A1max) )
            if fmax < h1:
                fmax = h1
                A1 = A1max
                A2 = A2max
                
            # (n_{i+1},n_{j+1})
            A1max = node[i+1]
            h1 = (h + interval_freq[j] * A2max - interval_freq[i] * A1max) / (
                2 * delta+math.exp(-epsilon)*(A2max - A1max) )
            if fmax < h1:
                fmax = h1
                A1 = A1max
                A2 = A2max 
                
            # (n_{i+1},n_j)
            A2max = node[j]
            h1 = (h + interval_freq[j] * A2max - interval_freq[i] * A1max) / (
                2 * delta+math.exp(-epsilon)*(A2max - A1max) )
            if fmax < h1:
                fmax = h1
                A1 = A1max
                A2 = A2max
                
                
            if d21 * d22 < 0:
                # (e_2,n_j)
                A1max = e2
                A2max = node[j]
                
                h1 = (h + interval_freq[j] * A2max - interval_freq[i] * A1max) / (
                    2 * delta+math.exp(-epsilon)*(A2max - A1max) )
                if fmax < h1:
                    fmax = h1
                    A1 = A1max
                    A2 = A2max
                
                # (e_2,n_{j+1})
                A2max = node[j+1]
                
                h1 = (h + interval_freq[j] * A2max - interval_freq[i] * A1max) / (
                    2 * delta+math.exp(-epsilon)*(A2max - A1max) )
                if fmax < h1:
                    fmax = h1
                    A1 = A1max
                    A2 = A2max
            if d11 * d12 < 0:
                # (n_i,e_1)
                A1max = node[i]
                A2max = e1
                
                h1 = (h + interval_freq[j] * A2max - interval_freq[i] * A1max) / (
                    2 * delta+math.exp(-epsilon)*(A2max - A1max) )
                if fmax < h1:
                    fmax = h1
                    A1 = A1max
                    A2 = A2max
                
                # (n_{i+1}, e_1)   
                A1max = node[i+1]
                
                h1 = (h + interval_freq[j] * A2max - interval_freq[i] * A1max) / (
                    2 * delta+math.exp(-epsilon)*(A2max - A1max ) )
                if fmax < h1:
                    fmax = h1
                    A1 = A1max
                    A2 = A2max  
                    
            if  d11 * d12 < 0 and d21 * d22 < 0:  
                # (e_2,e_1)
                A1max = e2
                A2max = e1      
                h1 = (h + interval_freq[j] * A2max - interval_freq[i] * A1max) / (
                    2 * delta+math.exp(-epsilon)*(A2max - A1max ))
                if fmax < h1:
                    A1 = A1max
                    A2 = A2max
    return A1, A2

In [None]:
def RPWithPrior3(train_loader, device, epsilon_total=0.1, delta=0.1):
    mechanism = "Laplace"
    # mechanism = "Gaussian"
    # mechanism = "staircase"

    epsilon1 = 0.01
    epsilon = epsilon_total - epsilon1

    rlp = RandomizedLabelPrivacy(epsilon1, mechanism, sensitivity=400, device=device)
    sensitivity = 400
    
    for i, (x, z, y) in enumerate(train_loader):
        x, z, y = x.to(device), z.to(device), y.to(device)
        target = y + rlp.noise(y.shape) 

        if i == 0:
            x_sets = x
            z_sets = z
            y_sets = y
            target_sets = target 
        else:
            x_sets = torch.cat((x_sets, x), 0)
            z_sets = torch.cat((z_sets, z), 0)
            y_sets = torch.cat((y_sets, y), 0)
            target_sets = torch.cat((y_sets, target), 0)
            
        # target_sets = y_sets
    target_sets = torch.max(target_sets, torch.zeros(target_sets.shape).to(device))
    # calculate the statistics of prior
    target_mean = target_sets.mean()
    target_std = target_sets.std() 
    
    # Step 2: calculate the histogram of prior 
    # calculate the value in each interval of the histogram
    k0 = ((torch.min(target_sets) - target_mean) / target_std ).floor().int().item()
    k1 = ((torch.max(target_sets) - target_mean) / target_std ).ceil().int().item()
    k = k1 - k0

    node = torch.zeros(k+1) # node in paper x_0...x_k
    interval_freq = torch.zeros(k) # value in each interval for histogram

    # calculate the relative frequency(probability) of each interval
    for i in range(k0, k1):
        if i == k0:
            node[i-k0] = torch.min(target_sets)
        else: 
            node[i-k0] = target_mean + i * target_std
        if i < k1 - 1:
            in_range = (target_sets - target_mean >= i * target_std) & \
                    (target_sets - target_mean < (i + 1) * target_std)
        else:
            in_range = (target_sets - target_mean >= i * target_std) & \
                    (target_sets - target_mean <= (i + 1) * target_std)
        interval_freq[i-k0] = in_range.sum().item()
    node[k] = torch.max(target_sets) 
    interval_freq = interval_freq / len(target_sets)
    
    # Step 3: RPWithPrior i.e. Algorithm 1 in this paper
    A1, A2 = compute_optimal_interval2(interval_freq, node, epsilon1, delta)
    while (A2 - A1 < 2 * delta):
        print('test')
        delta = (A2 - A1) / 2
        A1, A2 = compute_optimal_interval2(interval_freq, node, epsilon1, delta)
    print(torch.min(y_sets),interval_freq, A1, A2, torch.max(y_sets),(y_sets<A1).sum()/len(y_sets), (y_sets>A2).sum()/len(y_sets))
    
    # Step 4: add noise to target  ##### Algorithm 2 in this paper 
    # projection by Equation (3.6)  
    y_sets1 = y_sets.clone()   
    y_sets1[y_sets1 < A1] = A1 
    y_sets1[y_sets1 > A2] = A2

    rate = 1 / (math.exp(epsilon) *2 * delta + (A2 -A1))
    
    prob1 = (y_sets1 - A1) * rate 
    prob1[prob1 < 0] = 0
    prob2 = (A2 - y_sets1) * rate 
    prob2[prob2 < 0] = 0
    prob2 = 1- prob2
    
    new_label = 2 * torch.ones(len(y_sets1), dtype= int).to(device)
    random_tensor = torch.rand(len(y_sets1)).to(device)
    new_label[random_tensor - prob1 < 0] = 1
    new_label[random_tensor - prob2 > 0] = 3
    #############################
    y_tilde = y_sets1.clone()
    
    index = new_label == 1
    y_tilde[index] = A1 - delta + torch.rand(index.sum()).to(device) * torch.max(
        y_sets1[index] - A1,torch.zeros(index.sum()).to(device))
    index = new_label == 2
    y_tilde[index] = y_sets1[index] + delta * torch.empty_like(y_sets1[index]).uniform_(-1, 1).to(device)
    index = new_label == 3
    y_tilde[index] = A2 + delta - torch.rand(index.sum()).to(device) * torch.max(
        A2 - y_sets1[index],torch.zeros(index.sum()).to(device))
    
    return x_sets, z_sets, y_sets, y_tilde

In [None]:
epsilon = 1.5
delta = 70

model = Net_embedding(vocab_size=dataset.get_vocab()).to(device)

optimizer = torch.optim.RMSprop(model.parameters(), lr=1e-3, weight_decay=1e-4)
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=50, eta_min=5e-6)

loss_func = torch.nn.MSELoss()
epoch = 50
batch_size = 8192
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, 
      shuffle=True, num_workers=8, pin_memory=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, 
      shuffle=True, num_workers=6, pin_memory=True)

# Train                                      
x_sets, z_sets, y_sets, y_tilde = RPWithPrior3(train_loader, device, epsilon_total= epsilon, delta=delta)

# Train the model with the Label-DP dataset
labeldp_dataset = torch.utils.data.TensorDataset(x_sets.detach().cpu(),
                    z_sets.detach().cpu(), y_tilde.detach().cpu())
labeldp_loader = torch.utils.data.DataLoader(labeldp_dataset,
            batch_size=batch_size, shuffle=True, num_workers=6, pin_memory=True)
for i in tqdm(range(epoch)):
    losses = AverageMeter()
    for j, (x, z, y) in enumerate(labeldp_loader):
        x, z, y = x.to(device), z.to(device), y.to(device)
        optimizer.zero_grad()
        output = model(x, z)
        loss = loss_func(output.view(-1), y)
        loss.backward()
        optimizer.step()
        losses.update(loss.item(), x.shape[0])
    lr_scheduler.step()

    if i >=40: #i % 5 == 0 or i == epoch-1:
        train_loss = AverageMeter()
        test_loss = AverageMeter()
        with torch.no_grad():
            for x, z, y in train_loader:
                x, z, y = x.to(device), z.to(device), y.to(device)
                output = model(x, z)
                loss = loss_func(output.view(-1), y)
                train_loss.update(loss.item(), x.shape[0])
            for x, z, y in test_loader:
                x, z, y = x.to(device), z.to(device), y.to(device)
                output = model(x, z)
                loss = loss_func(output.view(-1), y)
                test_loss.update(loss.item(), x.shape[0])
        print("Epoch: {:>2}| Train Loss: {:.2f}| Train Loss: {:.2f}| Test Loss: {:.2f} ".format(i, 
                    losses.avg, train_loss.avg, test_loss.avg))


tensor(0., device='cuda:0') tensor([8.6593e-01, 1.3194e-01, 1.3305e-04, 1.2624e-04, 1.2322e-04, 1.0734e-04,
        1.0054e-04, 9.8273e-05, 9.1470e-05, 9.1470e-05, 7.6351e-05, 8.8446e-05,
        9.6761e-05, 7.3327e-05, 4.7625e-05, 4.6113e-05, 4.9137e-05, 5.3672e-05,
        5.1404e-05, 4.0821e-05, 4.6113e-05, 3.0994e-05, 2.4946e-05, 3.7041e-05,
        3.0994e-05, 2.6458e-05, 2.6458e-05, 3.0238e-05, 2.1922e-05, 2.7214e-05,
        2.1167e-05, 2.4190e-05, 1.8899e-05, 1.8899e-05, 1.4363e-05, 1.2851e-05,
        2.0411e-05, 1.1339e-05, 1.4363e-05, 9.8273e-06, 4.5357e-06, 5.2916e-06,
        1.2851e-05, 1.2851e-05, 6.0476e-06, 9.0714e-06, 1.0583e-05, 8.3154e-06,
        7.5595e-06, 3.7797e-06, 3.0238e-06, 6.0476e-06, 3.7797e-06, 7.5595e-06,
        3.7797e-06, 2.2678e-06, 3.7797e-06, 2.2678e-06, 3.7797e-06, 5.2916e-06,
        2.2678e-06, 3.0238e-06, 0.0000e+00, 2.2678e-06, 3.7797e-06, 7.5595e-07,
        2.2678e-06, 7.5595e-07, 3.0238e-06, 7.5595e-07, 0.0000e+00, 7.5595e-07,
        7.55

  0%|          | 0/50 [00:00<?, ?it/s]

Epoch: 40| Train Loss: 5445.76| Train Loss: 4591.06| Test Loss: 4635.33 
Epoch: 41| Train Loss: 5417.76| Train Loss: 4611.13| Test Loss: 4653.94 
Epoch: 42| Train Loss: 5386.72| Train Loss: 4502.89| Test Loss: 4545.11 
Epoch: 43| Train Loss: 5366.51| Train Loss: 4618.40| Test Loss: 4659.40 
Epoch: 44| Train Loss: 5347.84| Train Loss: 4521.35| Test Loss: 4563.93 
Epoch: 45| Train Loss: 5333.97| Train Loss: 4697.40| Test Loss: 4739.28 
Epoch: 46| Train Loss: 5323.80| Train Loss: 4643.39| Test Loss: 4685.23 
Epoch: 47| Train Loss: 5318.22| Train Loss: 4485.46| Test Loss: 4527.50 
Epoch: 48| Train Loss: 5312.74| Train Loss: 4471.18| Test Loss: 4513.50 
Epoch: 49| Train Loss: 5310.14| Train Loss: 4543.63| Test Loss: 4585.27 
