Skip to content

Commit

Permalink
Add more unit tests.
Browse files Browse the repository at this point in the history
Complete all unit tests for values.py

Fix unit tests for calculating offsets in expressions.

Fix error with division in expressions.

Move IO and error printing out of lower level classes.

Fix for statement parsing - should be catching OperandTypeError

Catch and raise ParseError instead of OperandTypeError.

Catch all Exception errors and raise them during translation.

Add negative PCR 16-bit offset test.

Add unit tests for version 1 release.

Refactor virtual file support for easier testing.

Attempt to cover all of CassetteFile class.

Refactor cassette and binary files to be inherited container classes.

Move unit tests into proper packages.

Add unit tests for CoCoFile objects.

Add unit tests for SourceFile objects.

Update assembler so it can produce binary files.

Update file_util with new virtual file structures.

Add missing test for BinaryFile objects.
  • Loading branch information
craigthomas committed Sep 3, 2022
1 parent d614b9d commit 02031af
Show file tree
Hide file tree
Showing 26 changed files with 1,959 additions and 839 deletions.
66 changes: 49 additions & 17 deletions assembler.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,13 @@
# I M P O R T S ###############################################################

import argparse
import sys

from cocoasm.exceptions import TranslationError, ParseError
from cocoasm.program import Program
from cocoasm.virtualfiles.virtualfile import CoCoFile
from cocoasm.virtualfiles.binary import BinaryFile
from cocoasm.virtualfiles.cassette import CassetteFile
from cocoasm.virtualfiles.virtual_file import VirtualFileType, VirtualFile
from cocoasm.virtualfiles.source_file import SourceFile, SourceFileType
from cocoasm.virtualfiles.coco_file import CoCoFile

# F U N C T I O N S ###########################################################

Expand Down Expand Up @@ -56,14 +58,34 @@ def parse_arguments():
return parser.parse_args()


def throw_error(error):
"""
Prints out an error message.
:param error: the error message to throw
"""
print(error.value)
print("{}".format(str(error.statement)))
sys.exit(1)


def main(args):
"""
Runs the assembler with the specified arguments.
:param args: the command-line arguments
"""
program = Program(width=args.width)
program.process(args.filename)
source_file = SourceFile(args.filename)
source_file.read_file()
program = Program()

try:
program.process(source_file.get_buffer())
except TranslationError as error:
throw_error(error)
except ParseError as error:
throw_error(error)

coco_file = CoCoFile(
name=program.name or args.name,
load_addr=program.origin,
Expand All @@ -72,18 +94,25 @@ def main(args):
)

if args.symbols:
program.print_symbol_table()
print("-- Symbol Table --")
for symbol in program.get_symbol_table():
print(symbol)

if args.print:
program.print_statements()
print("-- Assembled Statements --")
for statement in program.get_statements():
print(statement)

if args.bin_file:
try:
binary_file = BinaryFile()
binary_file.open_host_file_for_write(args.bin_file, append=args.append)
binary_file.save_to_host_file(coco_file)
binary_file.close_host_file()
except ValueError as error:
virtual_file = VirtualFile(
SourceFile(args.bin_file, file_type=SourceFileType.BINARY),
VirtualFileType.BINARY
)
virtual_file.open_virtual_file()
virtual_file.add_coco_file(coco_file)
virtual_file.save_virtual_file(append_mode=args.append)
except Exception as error:
print("Unable to save binary file:")
print(error)

Expand All @@ -92,11 +121,14 @@ def main(args):
print("No name for the program specified, not creating cassette file")
return
try:
cas_file = CassetteFile()
cas_file.open_host_file_for_write(args.cas_file, append=args.append)
cas_file.save_to_host_file(coco_file)
cas_file.close_host_file()
except ValueError as error:
virtual_file = VirtualFile(
SourceFile(args.cas_file, file_type=SourceFileType.BINARY),
VirtualFileType.CASSETTE
)
virtual_file.open_virtual_file()
virtual_file.add_coco_file(coco_file)
virtual_file.save_virtual_file(append_mode=args.append)
except Exception as error:
print("Unable to save cassette file:")
print(error)

Expand Down
24 changes: 11 additions & 13 deletions cocoasm/operands.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,8 +265,7 @@ def translate(self):
if registers[1] not in REGISTERS:
raise OperandTypeError("[{}] unknown register".format(registers[1]))

post_byte |= 0x00 if registers[0] == "D" else 0x00
post_byte |= 0x00 if registers[1] == "D" else 0x00
# Implicit is that a value of 0x00 means register D

post_byte |= 0x10 if registers[0] == "X" else 0x00
post_byte |= 0x01 if registers[1] == "X" else 0x00
Expand Down Expand Up @@ -471,7 +470,7 @@ def translate(self):
)
size = self.instruction.mode.ind_sz

if not type(self.value) == str and self.value.is_address():
if type(self.value) != str and self.value.is_address():
size += 2
return CodePackage(
op_code=NumericValue(self.instruction.mode.ind),
Expand All @@ -481,7 +480,7 @@ def translate(self):
max_size=size,
)

if not type(self.value) == str and self.value.is_numeric():
if type(self.value) != str and self.value.is_numeric():
size += 2
return CodePackage(
op_code=NumericValue(self.instruction.mode.ind),
Expand All @@ -507,7 +506,7 @@ def translate(self):
if "S" in self.right:
raw_post_byte |= 0x60

if self.left == "" or (not type(self.left) == str and self.left.is_numeric() and self.left.int == 0):
if self.left == "" or (type(self.left) != str and self.left.is_numeric() and self.left.int == 0):
if "-" in self.right or "+" in self.right:
if self.right == "X+" or self.right == "Y+" or self.right == "U+" or self.right == "S+":
raise OperandTypeError("[{}] not allowed as an extended indirect value".format(self.right))
Expand Down Expand Up @@ -603,13 +602,12 @@ def __init__(self, operand_string, instruction):
self.right = self.value.right

def resolve_symbols(self, symbol_table):
if self.left != "":
if self.left not in ["A", "B", "D"]:
self.left = Value.create_from_str(self.left, self.instruction, default_mode_extended=False)
if self.left.is_symbol():
self.left = self.left.resolve(symbol_table)
if self.left.is_address_expression() or self.left.is_expression():
self.left = self.left.resolve(symbol_table)
if self.left != "" and self.left not in ["A", "B", "D"]:
self.left = Value.create_from_str(self.left, self.instruction, default_mode_extended=False)
if self.left.is_symbol():
self.left = self.left.resolve(symbol_table)
if self.left.is_address_expression() or self.left.is_expression():
self.left = self.left.resolve(symbol_table)
return self

def translate(self):
Expand All @@ -634,7 +632,7 @@ def translate(self):
if "S" in self.right:
raw_post_byte |= 0x60

if self.left == "" or (not type(self.left) == str and self.left.is_numeric() and self.left.int == 0):
if self.left == "" or (type(self.left) != str and self.left.is_numeric() and self.left.int == 0):
raw_post_byte |= 0x80
if "-" in self.right or "+" in self.right:
if "+" in self.right:
Expand Down
154 changes: 62 additions & 92 deletions cocoasm/program.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,10 @@
"""
# I M P O R T S ###############################################################

import sys

from cocoasm.exceptions import TranslationError, ParseError, ValueTypeError
from cocoasm.exceptions import TranslationError, ValueTypeError
from cocoasm.statement import Statement
from cocoasm.values import AddressValue, NoneValue
from cocoasm.virtualfiles.source_file import SourceFile

# C L A S S E S ###############################################################

Expand All @@ -21,59 +20,38 @@ class Program(object):
contains a list of statements. Additionally, a Program keeps track of all
the user-defined symbols in the program.
"""
def __init__(self, width=100):
def __init__(self):
self.symbol_table = dict()
self.statements = []
self.address = 0x0
self.origin = NoneValue()
self.name = None
self.width = width

def process(self, filename):
def process(self, source_file):
"""
Processes a filename for assembly.
:param filename: the name of the file to process
:param source_file: the source file to process
"""
try:
self.parse(filename)
self.translate_statements()
except TranslationError as error:
self.throw_error(error)
except ParseError as error:
self.throw_error(error)
self.statements = self.parse(source_file)
self.translate_statements()

def parse(self, filename):
@classmethod
def parse(cls, contents):
"""
Parses a single file and saves the set of statements.
:param filename: the name of the file to process
"""
self.statements = self.parse_file(filename)

def parse_file(self, filename):
"""
Parses all of the lines in a file, and transforms each line into
a Statement. Returns a list of all the statements in the file.
:param filename: the name of the file to parse
:param contents: a list of strings, each string represents one line of assembly
"""
statements = []
if not filename:
return statements

try:
with open(filename) as infile:
for line in infile:
statement = Statement(line)
if not statement.is_empty and not statement.is_comment_only:
statements.append(statement)
except ParseError as error:
self.throw_error(error)

for line in contents:
statement = Statement(line)
if not statement.is_empty and not statement.is_comment_only:
statements.append(statement)
return statements

def process_mnemonics(self, statements):
@classmethod
def process_mnemonics(cls, statements):
"""
Given a list of statements, processes the mnemonics on each statement, and
assigns each statement an Instruction object. If the statement is the
Expand All @@ -85,8 +63,14 @@ def process_mnemonics(self, statements):
"""
processed_statements = []
for statement in statements:
include = self.process_mnemonics(self.parse_file(statement.get_include_filename()))
processed_statements.extend(include if include else [statement])
include_filename = statement.get_include_filename()
if include_filename:
include_source = SourceFile(include_filename)
include_source.read_file()
include = cls.process_mnemonics(cls.parse(include_source.get_buffer()))
processed_statements.extend(include)
else:
processed_statements.extend([statement])
return processed_statements

def save_symbol(self, index, statement):
Expand All @@ -112,46 +96,40 @@ def translate_statements(self):
Translates all the parsed statements into their respective
opcodes.
"""
try:
self.statements = self.process_mnemonics(self.statements)
for index, statement in enumerate(self.statements):
self.save_symbol(index, statement)
self.statements = self.process_mnemonics(self.statements)
for index, statement in enumerate(self.statements):
self.save_symbol(index, statement)

for index, statement in enumerate(self.statements):
try:
statement.resolve_symbols(self.symbol_table)
except ValueTypeError as error:
raise TranslationError(str(error), statement)
for index, statement in enumerate(self.statements):
statement.resolve_symbols(self.symbol_table)

for index, statement in enumerate(self.statements):
statement.translate()

while not self.all_sizes_fixed():
for index, statement in enumerate(self.statements):
statement.translate()
if not statement.fixed_size:
statement.determine_pcr_relative_sizes(self.statements, index)

while not self.all_sizes_fixed():
for index, statement in enumerate(self.statements):
if not statement.fixed_size:
statement.determine_pcr_relative_sizes(self.statements, index)
address = 0
for index, statement in enumerate(self.statements):
address = statement.set_address(address)
address += statement.code_pkg.size

address = 0
for index, statement in enumerate(self.statements):
address = statement.set_address(address)
address += statement.code_pkg.size
for index, statement in enumerate(self.statements):
statement.fix_addresses(self.statements, index)

for index, statement in enumerate(self.statements):
statement.fix_addresses(self.statements, index)

# Update the symbol table with the proper addresses
for symbol, value in self.symbol_table.items():
if value.is_address():
self.symbol_table[symbol] = self.statements[value.int].code_pkg.address

# Find the origin and name of the project
for statement in self.statements:
if statement.instruction.is_origin:
self.origin = statement.code_pkg.address
if statement.instruction.is_name:
self.name = statement.operand.operand_string
except TranslationError as error:
self.throw_error(error)
# Update the symbol table with the proper addresses
for symbol, value in self.symbol_table.items():
if value.is_address():
self.symbol_table[symbol] = self.statements[value.int].code_pkg.address

# Find the origin and name of the project
for statement in self.statements:
if statement.instruction.is_origin:
self.origin = statement.code_pkg.address
if statement.instruction.is_name:
self.name = statement.operand.operand_string

def get_binary_array(self):
"""
Expand Down Expand Up @@ -187,30 +165,22 @@ def all_sizes_fixed(self):
return False
return True

def print_symbol_table(self):
def get_symbol_table(self):
"""
Prints out the symbol table and any values contained within it.
Returns a list of strings. Each string contains one entry from the symbol table.
"""
print("-- Symbol Table --")
lines = []
for symbol, value in self.symbol_table.items():
print("${} {}".format(value.hex().ljust(4, ' '), symbol))
lines.append("${} {}".format(value.hex().ljust(4, ' '), symbol))
return lines

def print_statements(self):
def get_statements(self):
"""
Prints out the assembled statements.
Returns a list of strings. Each string represents one assembled statement
"""
print("-- Assembled Statements --")
lines = []
for index, statement in enumerate(self.statements):
print("{}".format(str(statement).ljust(self.width)))

def throw_error(self, error):
"""
Prints out an error message.
:param error: the error message to throw
"""
print(error.value)
print("{}".format(str(error.statement).ljust(self.width)))
sys.exit(1)
lines.append("{}".format(str(statement)))
return lines

# E N D O F F I L E #######################################################

0 comments on commit 02031af

Please sign in to comment.