In [14]:
import numpy as np
import torch

def create_conflict_matrix(number_of_person: int, number_of_conflict: int) -> np.ndarray:
    conflict_matrix = np.zeros([number_of_person, number_of_person])
    while sum(sum(conflict_matrix)) < number_of_conflict:
        first = np.random.randint(number_of_person)
        second = np.random.randint(number_of_person)
        if first != second and conflict_matrix[first][second] != 1:
            conflict_matrix[first][second] = 1
    return conflict_matrix

def describe_conflicts(conflict_matrix):
    num_people = conflict_matrix.shape[0]
    conflicts = []

    for i in range(num_people):
        for j in range(i+1, num_people):  # We only need to check the upper triangle
            if conflict_matrix[i, j] == 1:
                conflicts.append(f"There is conflict between person {i} and person {j}")

    return conflicts

conflict_matrix = create_conflict_matrix(24, 40)
conflict_matrix_tensor = torch.tensor(conflict_matrix, dtype=torch.float32, requires_grad=False)

# Get the conflict descriptions
conflict_descriptions = describe_conflicts(conflict_matrix)

print("Conflicts:")
for description in conflict_descriptions:
    print(description)

Conflicts:
There is conflict between person 0 and person 3
There is conflict between person 1 and person 13
There is conflict between person 2 and person 8
There is conflict between person 2 and person 21
There is conflict between person 2 and person 22
There is conflict between person 3 and person 4
There is conflict between person 3 and person 6
There is conflict between person 3 and person 14
There is conflict between person 4 and person 12
There is conflict between person 4 and person 15
There is conflict between person 5 and person 7
There is conflict between person 5 and person 17
There is conflict between person 7 and person 9
There is conflict between person 7 and person 12
There is conflict between person 8 and person 12
There is conflict between person 8 and person 19
There is conflict between person 10 and person 13
There is conflict between person 12 and person 17
There is conflict between person 13 and person 14
There is conflict between person 16 and person 18


In [15]:

def compute_cost(seating_probs, conflict_matrix, cost_rate, uniqueness_cost_rate, multiple_assignment_cost_rate, seating_shape):
    batch_size, num_seats, num_people = seating_probs.shape
    
    # Reshape seating probabilities
    seating_probs = seating_probs.view(batch_size, seating_shape[0], seating_shape[1], num_people)
    
    # Conflict cost
    conflict_costs = []
    for i in range(seating_shape[0]):
        for j in range(seating_shape[1]):
            person_probs = seating_probs[:, i, j, :]
            if j + 1 < seating_shape[1]:
                neighbor_probs = seating_probs[:, i, j+1, :]
                conflict_costs.append(torch.sum(torch.matmul(person_probs, torch.matmul(conflict_matrix, neighbor_probs.t()))))
            if i + 1 < seating_shape[0]:
                neighbor_probs = seating_probs[:, i+1, j, :]
                conflict_costs.append(torch.sum(torch.matmul(person_probs, torch.matmul(conflict_matrix, neighbor_probs.t()))))

    conflict_cost = torch.sum(torch.stack(conflict_costs)) * cost_rate

    # Uniqueness cost (encourage one person per seat)
    uniqueness_cost = uniqueness_cost_rate * torch.sum((torch.sum(seating_probs, dim=3) - 1).pow(2))
    
    # Multiple assignment cost (discourage one person in multiple seats)
    multiple_assignment_cost = multiple_assignment_cost_rate * torch.sum((torch.sum(seating_probs, dim=(1,2)) - 1).pow(2))

    total_cost = conflict_cost + uniqueness_cost + multiple_assignment_cost
    return total_cost



class SeatingNetwork(torch.nn.Module):
    def __init__(self, input_layer, hidden_layer, output_layer, seating_shape):
        super().__init__()
        self.input_layer = torch.nn.Linear(input_layer, hidden_layer)
        self.hidden_layer = torch.nn.Linear(hidden_layer, hidden_layer)
        self.output_layer = torch.nn.Linear(hidden_layer, output_layer)
        self.relu = torch.nn.ReLU()
        self.softmax = torch.nn.Softmax(dim=1)
        self.seating_shape = seating_shape

    def forward(self, x):
        x = self.relu(self.input_layer(x))
        x = self.relu(self.hidden_layer(x))
        x = self.output_layer(x)
        x = x.view(-1, self.seating_shape[0] * self.seating_shape[1], 24)
        x = self.softmax(x)
        return x

# Define input parameters
num_people = 24
input_size = num_people
hidden_size = 800
output_size = 24 * 24
cost_rate = 10
uniqueness_cost_rate = 100
multiple_assignment_cost_rate = 50
seating_shape = [6, 4]

# Initialize the neural network
model = SeatingNetwork(input_size, hidden_size, output_size, seating_shape)

# Define the optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# Example data
people = np.arange(num_people)

# Convert people to tensor
people_tensor = torch.tensor(people, dtype=torch.float32).unsqueeze(0)

# Training loop
for epoch in range(500):
    model.train()
    optimizer.zero_grad()
    output = model(people_tensor)
    
    loss = compute_cost(output, conflict_matrix_tensor, cost_rate, uniqueness_cost_rate, multiple_assignment_cost_rate, seating_shape)
    
    loss.backward()
    optimizer.step()

    if epoch % 10 == 0:
        print(f"Epoch [{epoch}], Loss: {loss.item()}")

# Example prediction
model.eval()
with torch.no_grad():
    output = model(people_tensor)
    seating_array = output.view(-1, *seating_shape, num_people).argmax(dim=3)
    print(seating_array)

Epoch [0/1000], Loss: 251.73248291015625
Epoch [10/1000], Loss: 57.83466720581055
Epoch [20/1000], Loss: 14.654181480407715
Epoch [30/1000], Loss: 1.9989800453186035
Epoch [40/1000], Loss: 0.5413824319839478
Epoch [50/1000], Loss: 0.24045085906982422
Epoch [60/1000], Loss: 0.13968193531036377
Epoch [70/1000], Loss: 0.19259995222091675
Epoch [80/1000], Loss: 0.09153880178928375
Epoch [90/1000], Loss: 0.05688200145959854
Epoch [100/1000], Loss: 0.053822167217731476
Epoch [110/1000], Loss: 0.0468902587890625
Epoch [120/1000], Loss: 0.04373352602124214
Epoch [130/1000], Loss: 0.04065030440688133
Epoch [140/1000], Loss: 0.0379326194524765
Epoch [150/1000], Loss: 0.03550676256418228
Epoch [160/1000], Loss: 0.03335242345929146
Epoch [170/1000], Loss: 0.03137677162885666
Epoch [180/1000], Loss: 0.2969515323638916
Epoch [190/1000], Loss: 0.03757844120264053
Epoch [200/1000], Loss: 0.027462469413876534
Epoch [210/1000], Loss: 0.02395297959446907
Epoch [220/1000], Loss: 0.021405640989542007
Epoch